diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..0df37b1230 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,23 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + groups: + github-actions: + patterns: + - "*" + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "monthly" + groups: + python: + patterns: + - "*" diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 64a41ca16e..16916c8fbd 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -15,9 +15,9 @@ jobs: name: Test the code with Keras 2 runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.9 - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: 3.9 - name: Get pip cache dir @@ -26,7 +26,7 @@ jobs: python -m pip install --upgrade pip setuptools echo "::set-output name=dir::$(pip cache dir)" - name: pip cache - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -51,9 +51,9 @@ jobs: backend: [tensorflow, jax, torch] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.9 - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: 3.9 - name: Get pip cache dir @@ -62,7 +62,7 @@ jobs: python -m pip install --upgrade pip setuptools echo "::set-output name=dir::$(pip cache dir)" - name: pip cache - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -81,9 +81,9 @@ jobs: name: Check the code format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.9 - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: 3.9 - name: Get pip cache dir @@ -92,7 +92,7 @@ jobs: python -m pip install --upgrade pip setuptools echo "::set-output name=dir::$(pip cache dir)" - name: pip cache - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 677a641658..686421fe00 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -16,9 +16,9 @@ jobs: needs: [run-test-for-nightly] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.9 - name: Get pip cache dir @@ -27,7 +27,7 @@ jobs: python -m pip install --upgrade pip setuptools echo "::set-output name=dir::$(pip cache dir)" - name: pip cache - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index 
c3f6767350..186b7668da 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -10,9 +10,9 @@ jobs: name: Build and publish to PyPI runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.9 - name: Get pip cache dir @@ -21,7 +21,7 @@ jobs: python -m pip install --upgrade pip setuptools echo "::set-output name=dir::$(pip cache dir)" - name: pip cache - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py new file mode 100644 index 0000000000..8a66ad05af --- /dev/null +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -0,0 +1,137 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +@keras_nlp_export("keras_nlp.layers.AlibiBias") +class AlibiBias(keras.layers.Layer): + """A layer that adds the alibi bias to attention scores. + + This layer adds the alibi bias to the attention scores. Alibi bias is a + linear, non-learned bias. Defined and formalized in + [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409). + + This layer takes as input the attention scores. and returns the attention + scores after adding the alibi bias to it. The output will have the same + shape as the input. + + Args: + alibi_bias_max: int. This value will be used to compute the slope of + each head. The heads' slopes are a geometric sequence that starts at + `2**(-alibi_bias_max/num_heads)` and uses that same value as its + ratio. Defaults to 8. + Call arguments: + attention_scores: The result of multipying the query and the key of the + multi-head attention layer of the transformer to add alibi bias to + it. With shape `(batch_size, num_heads, query_length, key_length)`. + + Examples: + ```python + query_length = 10 + key_length = 10 + num_heads = 4 + batch_size = 2 + hidden_dim = 8 + + # Create new alibi layer. + alibi_layer = keras_nlp.layers.AlibiBias() + + query = np.zeros((batch_size, num_heads, query_length, hidden_dim)) + key = np.zeros((batch_size, num_heads, hidden_dim, key_length)) + + attention_scores = keras.ops.matmul(query, key) + + # Add alibi bias to attention scores. 
+ attention_scores = alibi_layer(attention_scores) + ``` + + References: + - [Press et al., 2021](https://arxiv.org/abs/2108.12409) + """ + + def __init__( + self, + alibi_bias_max=8, + **kwargs, + ): + super().__init__(**kwargs) + self.alibi_bias_max = alibi_bias_max + + def call(self, attention_scores): + shape = ops.shape(attention_scores) + if len(shape) != 4: + raise ValueError( + "Expected `attention_scores` shape to be " + "`(batch_size, num_heads, query_length, key_Length)`." + f" Recived shape={shape}" + ) + + key_length = shape[-1] + num_heads = shape[-3] + + alibi_bias = self._get_alibi_bias(num_heads, key_length) + + return ops.add(attention_scores, alibi_bias) + + def _get_alibi_bias(self, num_heads, key_length): + slopes = ops.convert_to_tensor( + self._get_slopes(num_heads), dtype=self.compute_dtype + ) + slopes = ops.expand_dims(slopes, 1) + + seq_range = ops.expand_dims(ops.arange(1 - key_length, 1), 0) + seq_range = ops.cast(seq_range, dtype=self.compute_dtype) + + alibi_bias = ops.multiply(slopes, seq_range) + alibi_bias = ops.expand_dims(alibi_bias, 1) + + # return shape is `(1, num_heads, 1, key_length)` + return ops.expand_dims(alibi_bias, 0) + + def _get_slopes(self, num_heads): + # this function is adopted from Alibi original implementation. + # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + def get_slopes_power_of_2(n): + start = 2 ** ( + -(2 ** -(math.log2(n) - math.log2(self.alibi_bias_max))) + ) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(num_heads).is_integer(): + return get_slopes_power_of_2(num_heads) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + self._get_slopes(2 * closest_power_of_2)[0::2][ + : num_heads - closest_power_of_2 + ] + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "alibi_bias_max": self.alibi_bias_max, + } + ) + return config diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py new file mode 100644 index 0000000000..120c7622b1 --- /dev/null +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -0,0 +1,202 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
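For cross-checking the expected values in the tests that follow, here is a minimal standalone sketch (plain Python, power-of-two head counts only, using the layer's default `alibi_bias_max=8`) of how the slopes and bias offsets computed by `AlibiBias._get_slopes` above are derived:

```python
import math

def alibi_slopes(num_heads, alibi_bias_max=8):
    # Geometric sequence that starts at 2**(-alibi_bias_max / num_heads)
    # and uses that same value as its ratio (power-of-two case only).
    start = 2 ** (-(2 ** -(math.log2(num_heads) - math.log2(alibi_bias_max))))
    return [start * start**i for i in range(num_heads)]

print(alibi_slopes(8))  # [0.5, 0.25, 0.125, ..., 0.00390625]

# With key_length=3 the per-position offsets are [-2, -1, 0], so the bias for
# head 0 is [-1.0, -0.5, 0.0], the first row expected by `test_correct_output`.
```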
+ +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.backend import random +from keras_nlp.layers.modeling.alibi_bias import AlibiBias +from keras_nlp.tests.test_case import TestCase + + +class AlibiBiasTest(TestCase): + def test_layer_behaviors(self): + alibi_bias_max = 8 + batch_size = 4 + num_heads = 8 + query_length = 10 + key_length = 10 + self.run_layer_test( + cls=AlibiBias, + init_kwargs={ + "alibi_bias_max": alibi_bias_max, + }, + input_data=random.uniform( + shape=(batch_size, num_heads, query_length, key_length) + ), + expected_output_shape=( + batch_size, + num_heads, + query_length, + key_length, + ), + ) + + def test_float16_dtype(self): + # Create a 4-dimensional input (the first dimension is implicit). + alibi_bias_max = 8 + num_heads = 8 + query_length = 5 + key_length = 10 + test_layer = AlibiBias(alibi_bias_max=alibi_bias_max, dtype="float16") + input_tensor = keras.Input(shape=(num_heads, query_length, key_length)) + output_tensor = test_layer(input_tensor) + + # the output is expected to be the same as the input shape in all + # dimensions. here, the first dimension is implicit and is for batch + expected_output_shape = (None, num_heads, query_length, key_length) + self.assertEqual(expected_output_shape, output_tensor.shape) + # The default output dtype for this layer should be "float32". + self.assertEqual("float16", output_tensor.dtype) + + def test_dynamic_layer_output_shape(self): + query_length = 10 + key_length = 10 + num_heads = 4 + + test_layer = AlibiBias() + # Create a 4-dimensional input (the first dimension is implicit). + input_tensor = keras.Input(shape=(num_heads, query_length, key_length)) + output_tensor = test_layer(input_tensor) + + # the output is expected to be the same as the input shape in all + # dimensions. 
+ expected_output_shape = ( + None, + num_heads, + query_length, + key_length, + ) + self.assertEqual(expected_output_shape, output_tensor.shape) + + def test_value_error_when_inputs_shape_is_not_4(self): + with self.assertRaises(ValueError): + AlibiBias()(random.uniform(shape=(12, 12))) + + def test_num_heads_is_not_power_of_two(self): + inputs_shape = (1, 12, 12, 12) + inputs = random.uniform(shape=inputs_shape) + layer = AlibiBias() + outputs = layer(inputs) + self.assertEqual(inputs_shape, outputs.shape) + + def test_correct_output(self): + batch_size = 1 + num_heads = 8 + query_length = 1 + key_length = 3 + input_shape = (batch_size, num_heads, query_length, key_length) + input_tensor = ops.zeros(input_shape) + layer = AlibiBias() + output_tensor = layer(input_tensor) + print(output_tensor) + self.assertAllClose( + output_tensor, + ops.convert_to_tensor( + [ + [ + [[-1.0, -0.5, 0.0]], + [[-0.5, -0.25, 0.0]], + [[-0.25, -0.125, 0.0]], + [[-0.125, -0.0625, 0.0]], + [[-0.0625, -0.03125, 0.0]], + [[-0.03125, -0.015625, 0.0]], + [[-0.015625, -0.0078125, 0.0]], + [[-0.0078125, -0.00390625, 0.0]], + ] + ] + ), + ) + + def test_correct_output_num_heads_not_power_of_two(self): + batch_size = 1 + num_heads = 14 + query_length = 1 + key_length = 3 + input_shape = (batch_size, num_heads, query_length, key_length) + input_tensor = ops.zeros(input_shape) + layer = AlibiBias() + output_tensor = layer(input_tensor) + print(output_tensor) + self.assertAllClose( + output_tensor, + ops.convert_to_tensor( + [ + [ + [[-1.0, -0.5, 0.0]], + [[-0.5, -0.25, 0.0]], + [[-0.25, -0.125, 0.0]], + [[-0.125, -0.0625, 0.0]], + [[-0.0625, -0.03125, 0.0]], + [[-0.03125, -0.015625, 0.0]], + [[-0.015625, -0.0078125, 0.0]], + [[-0.0078125, -0.00390625, 0.0]], + [[-1.4142135, -0.70710677, 0.0]], + [[-0.70710677, -0.35355338, 0.0]], + [[-0.35355338, -0.17677669, 0.0]], + [[-0.17677669, -0.08838835, 0.0]], + [[-0.08838835, -0.04419417, 0.0]], + [[-0.04419417, -0.02209709, 0.0]], + ] + ] + ), + ) + + def test_correct_output_alibi_bias_max(self): + alibi_bias_max = 12 + batch_size = 1 + num_heads = 2 + query_length = 1 + key_length = 3 + input_shape = (batch_size, num_heads, query_length, key_length) + input_tensor = ops.zeros(input_shape) + layer = AlibiBias(alibi_bias_max=alibi_bias_max) + output_tensor = layer(input_tensor) + print(output_tensor) + self.assertAllClose( + output_tensor, + ops.convert_to_tensor( + [ + [ + [[-0.03125, -0.015625, 0.0]], + [[-0.00048828, -0.00024414, 0.0]], + ] + ] + ), + ) + + def test_correct_output_alibi_bias_max_num_heads_not_power_of_two( + self, + ): + alibi_bias_max = 6 + batch_size = 1 + num_heads = 3 + query_length = 1 + key_length = 3 + input_shape = (batch_size, num_heads, query_length, key_length) + input_tensor = ops.zeros(input_shape) + layer = AlibiBias(alibi_bias_max=alibi_bias_max) + output_tensor = layer(input_tensor) + print(output_tensor) + self.assertAllClose( + output_tensor, + ops.convert_to_tensor( + [ + [ + [[-0.25, -0.125, 0.0]], + [[-0.03125, -0.015625, 0.0]], + [[-0.70710677, -0.35355338, 0.0]], + ] + ] + ), + ) diff --git a/keras_nlp/layers/modeling/cached_multi_head_attention_test.py b/keras_nlp/layers/modeling/cached_multi_head_attention_test.py index 052ce66ec1..6bf4311423 100644 --- a/keras_nlp/layers/modeling/cached_multi_head_attention_test.py +++ b/keras_nlp/layers/modeling/cached_multi_head_attention_test.py @@ -28,6 +28,7 @@ def test_layer_behaviors(self): init_kwargs={ "num_heads": 2, "key_dim": 4, + "dropout": 0.1, }, input_data={ "query": 
random.uniform(shape=(2, 4, 6)), @@ -38,7 +39,7 @@ def test_layer_behaviors(self): expected_num_non_trainable_variables=1, # Keras 2 does not handle mixed precision correctly when not set # globally. - run_mixed_precision_check=config.keras_3(), + run_precision_checks=config.keras_3(), ) def test_cache_call_is_correct(self): diff --git a/keras_nlp/layers/modeling/f_net_encoder_test.py b/keras_nlp/layers/modeling/f_net_encoder_test.py index e5d0b1ea77..47c189ef11 100644 --- a/keras_nlp/layers/modeling/f_net_encoder_test.py +++ b/keras_nlp/layers/modeling/f_net_encoder_test.py @@ -23,11 +23,11 @@ def test_layer_behaviors(self): cls=FNetEncoder, init_kwargs={ "intermediate_dim": 4, - "dropout": 0, "activation": "relu", "layer_norm_epsilon": 1e-5, "kernel_initializer": "HeNormal", "bias_initializer": "Zeros", + "dropout": 0.1, }, input_data=random.uniform(shape=(2, 4, 6)), expected_output_shape=(2, 4, 6), diff --git a/keras_nlp/layers/modeling/masked_lm_head_test.py b/keras_nlp/layers/modeling/masked_lm_head_test.py index 8d22ea0343..69a6288911 100644 --- a/keras_nlp/layers/modeling/masked_lm_head_test.py +++ b/keras_nlp/layers/modeling/masked_lm_head_test.py @@ -58,6 +58,7 @@ def test_layer_behaviors_with_embedding(self): }, expected_output_shape=(4, 5, 100), expected_num_trainable_weights=6, + run_precision_checks=False, ) def test_value_error_when_neither_embedding_or_vocab_size_set(self): diff --git a/keras_nlp/layers/modeling/sine_position_encoding_test.py b/keras_nlp/layers/modeling/sine_position_encoding_test.py index 80dad26cbc..e7ea4b1b05 100644 --- a/keras_nlp/layers/modeling/sine_position_encoding_test.py +++ b/keras_nlp/layers/modeling/sine_position_encoding_test.py @@ -107,13 +107,3 @@ def test_start_index(self): sequential_output, (0, i, 0), parial_output ) self.assertAllClose(full_output, sequential_output) - - def test_float16_dtype(self): - pos_encoding = SinePositionEncoding(dtype="float16") - seq_length = 100 - hidden_size = 32 - inputs = keras.Input(shape=(seq_length, hidden_size)) - outputs = pos_encoding(inputs) - - # output dtype for this layer should be tf.float16. 
- self.assertEqual(outputs.dtype, "float16") diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py index 3a3cda3f21..15c245768c 100644 --- a/keras_nlp/layers/modeling/transformer_decoder.py +++ b/keras_nlp/layers/modeling/transformer_decoder.py @@ -416,10 +416,10 @@ def call( cache=cross_attention_cache, cache_update_index=cross_attention_cache_update_index, ) - if self_attention_cache is None: + if cross_attention_cache is None: x = attention_output else: - x, self_attention_cache = attention_output + x, cross_attention_cache = attention_output x = self._cross_attention_dropout(x) x = x + residual if not self.normalize_first: @@ -469,9 +469,11 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index, + ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/layers/modeling/transformer_decoder_test.py b/keras_nlp/layers/modeling/transformer_decoder_test.py index 2b54324f02..aa16d9ae5a 100644 --- a/keras_nlp/layers/modeling/transformer_decoder_test.py +++ b/keras_nlp/layers/modeling/transformer_decoder_test.py @@ -36,6 +36,7 @@ def test_layer_behaviors(self, normalize_first): "layer_norm_epsilon": 1e-05, "kernel_initializer": "HeNormal", "bias_initializer": "Zeros", + "dropout": 0.1, }, input_data=random.uniform(shape=(2, 4, 6)), expected_output_shape=(2, 4, 6), @@ -48,7 +49,6 @@ def test_layer_behaviors(self, normalize_first): ("with_norm_first", True), ) def test_layer_behaviors_with_cross_attention(self, normalize_first): - pass self.run_layer_test( cls=TransformerDecoder, init_kwargs={ @@ -59,6 +59,7 @@ def test_layer_behaviors_with_cross_attention(self, normalize_first): "layer_norm_epsilon": 1e-05, "kernel_initializer": "HeNormal", "bias_initializer": "Zeros", + "dropout": 0.1, }, input_data={ "decoder_sequence": random.uniform(shape=(2, 4, 6)), @@ -126,10 +127,7 @@ def test_mask_propagation_without_cross_attention(self): self.assertAllEqual(outputs._keras_mask, mask) def test_cache_call_is_correct(self): - batch_size = 2 - seq_len = 5 - num_heads = 2 - key_dim = 4 + batch_size, seq_len, num_heads, key_dim = 2, 5, 2, 4 hidden_dim = num_heads * key_dim input_shape = (batch_size, seq_len, hidden_dim) @@ -170,6 +168,59 @@ def call(outputs, cache): self.assertAllClose(output, no_loop_outputs) self.assertAllClose(output_cache, no_loop_cache) + def test_cache_call_is_correct_with_cross_attention(self): + batch_size, seq_len, num_heads, key_dim = 2, 5, 2, 4 + hidden_dim = num_heads * key_dim + + input_shape = (batch_size, seq_len, hidden_dim) + cache_shape = (batch_size, 2, seq_len, num_heads, key_dim) + decoder_sequence = random.uniform(shape=input_shape) + encoder_sequence = random.uniform(shape=input_shape) + empty_cache = ops.zeros(cache_shape) + outputs = ops.zeros_like(decoder_sequence) + + layer = TransformerDecoder( + intermediate_dim=4, + num_heads=num_heads, + ) + no_loop_outputs, no_loop_self_cache, no_loop_cross_cache = layer( + decoder_sequence, + encoder_sequence, + self_attention_cache=empty_cache, + self_attention_cache_update_index=0, + cross_attention_cache=empty_cache, + cross_attention_cache_update_index=0, + ) + + def loop_body(i, outputs, self_cache, cross_cache): + # Compute the rest tokens. 
+ start, size = (0, i, 0), (batch_size, 1, hidden_dim) + next_input = ops.slice(decoder_sequence, start, size) + next_output, self_cache, cross_cache = layer( + decoder_sequence=next_input, + encoder_sequence=encoder_sequence, + self_attention_cache=self_cache, + self_attention_cache_update_index=i, + cross_attention_cache=cross_cache, + ) + outputs = ops.slice_update(outputs, start, next_output) + return i + 1, outputs, self_cache, cross_cache + + def call(outputs, self_cache, cross_cache): + _, outputs, self_cache, cross_cache = ops.while_loop( + cond=lambda i, outputs, self_cache, cross_cache: i < seq_len, + body=loop_body, + loop_vars=[0, outputs, self_cache, cross_cache], + ) + return outputs, self_cache, cross_cache + + output, self_cache, cross_cache = call( + outputs, empty_cache, no_loop_cross_cache + ) + self.assertAllClose(output, no_loop_outputs) + self.assertAllClose(self_cache, no_loop_self_cache) + self.assertAllClose(cross_cache, no_loop_cross_cache) + def test_different_feature_dimension_for_encoder_and_decoder_sequence(self): decoder = TransformerDecoder( intermediate_dim=4, diff --git a/keras_nlp/layers/modeling/transformer_encoder_test.py b/keras_nlp/layers/modeling/transformer_encoder_test.py index 844125c4b0..edcfdfc470 100644 --- a/keras_nlp/layers/modeling/transformer_encoder_test.py +++ b/keras_nlp/layers/modeling/transformer_encoder_test.py @@ -37,6 +37,7 @@ def test_layer_behaviors(self, normalize_first): "layer_norm_epsilon": 1e-05, "kernel_initializer": "HeNormal", "bias_initializer": "Zeros", + "dropout": 0.1, }, input_data=random.uniform(shape=(2, 4, 6)), expected_output_shape=(2, 4, 6), diff --git a/keras_nlp/layers/modeling/transformer_layer_utils.py b/keras_nlp/layers/modeling/transformer_layer_utils.py index 863da59a36..f375bf1b9d 100644 --- a/keras_nlp/layers/modeling/transformer_layer_utils.py +++ b/keras_nlp/layers/modeling/transformer_layer_utils.py @@ -55,9 +55,12 @@ def compute_causal_mask(batch_size, input_length, output_length, cache_index=0): `(batch_size, output_length, input_length)` that can be passed to a attention layer. 
""" - i = ops.expand_dims(ops.arange(output_length), axis=1) + cache_index - j = ops.arange(input_length) - mask = ops.expand_dims(ops.cast(i >= j, dtype="int32"), axis=0) + i = ops.arange(output_length, dtype="float32") + i = i + ops.cast(cache_index, "float32") + i = ops.expand_dims(i, axis=1) + j = ops.arange(input_length, dtype="float32") + mask = ops.expand_dims(i >= j, axis=0) + return ops.broadcast_to(mask, (batch_size, output_length, input_length)) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index ab04d8eae0..cdd50670f3 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -35,6 +35,8 @@ ) from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor from keras_nlp.models.bert.bert_tokenizer import BertTokenizer +from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone from keras_nlp.models.deberta_v3.deberta_v3_classifier import ( DebertaV3Classifier, @@ -73,6 +75,13 @@ ) from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( @@ -91,6 +100,12 @@ from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.models.llama.llama_backbone import LlamaBackbone from keras_nlp.models.mistral.mistral_backbone import MistralBackbone +from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM +from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import ( + MistralCausalLMPreprocessor, +) +from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor +from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer from keras_nlp.models.opt.opt_backbone import OPTBackbone from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM from keras_nlp.models.opt.opt_causal_lm_preprocessor import ( diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py index 414bb97e87..1e342e791c 100644 --- a/keras_nlp/models/albert/albert_backbone.py +++ b/keras_nlp/models/albert/albert_backbone.py @@ -72,6 +72,10 @@ class AlbertBackbone(Backbone): embeddings. num_segments: int. The number of types that the 'segment_ids' input can take. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Examples: ```python @@ -110,6 +114,7 @@ def __init__( dropout=0.0, max_sequence_length=512, num_segments=2, + dtype=None, **kwargs, ): if num_layers % num_groups != 0: @@ -118,112 +123,108 @@ def __init__( f"`num_layers={num_layers}` and `num_groups={num_groups}`." 
) - # Index of classification token in the vocabulary - cls_token_index = 0 - # Inputs - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" - ) - segment_id_input = keras.Input( - shape=(None,), dtype="int32", name="segment_ids" - ) - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens, positions, and segment ids. - token_embedding_layer = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=embedding_dim, embeddings_initializer=albert_kernel_initializer(), + dtype=dtype, name="token_embedding", ) - token_embedding = token_embedding_layer(token_id_input) - position_embedding = PositionEmbedding( + self.position_embedding = PositionEmbedding( initializer=albert_kernel_initializer(), sequence_length=max_sequence_length, + dtype=dtype, name="position_embedding", - )(token_embedding) - segment_embedding = keras.layers.Embedding( + ) + self.segment_embedding = keras.layers.Embedding( input_dim=num_segments, output_dim=embedding_dim, embeddings_initializer=albert_kernel_initializer(), + dtype=dtype, name="segment_embedding", - )(segment_id_input) - - # Sum, normalize and apply dropout to embeddings. - x = keras.layers.Add()( - (token_embedding, position_embedding, segment_embedding) ) - x = keras.layers.LayerNormalization( - name="embeddings_layer_norm", + self.embeddings_add = keras.layers.Add( + dtype=dtype, + name="embeddings_add", + ) + self.embeddings_layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-12, - dtype="float32", - )(x) - x = keras.layers.Dropout( + dtype=dtype, + name="embeddings_layer_norm", + ) + self.embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="embeddings_dropout", - )(x) - - # Project the embedding to `hidden_dim`. - x = keras.layers.Dense( + ) + self.embeddings_projection = keras.layers.Dense( hidden_dim, kernel_initializer=albert_kernel_initializer(), + dtype=dtype, name="embedding_projection", - )(x) - - def get_group_layer(group_idx): - """Defines a group `num_inner_repetitions` transformer layers and - returns the callable. - """ - transformer_layers = [ - TransformerEncoder( + ) + self.transformer_layers = [] + for group_idx in range(num_groups): + inner_layers = [] + for inner_idx in range(num_inner_repetitions): + layer = TransformerEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, activation=gelu_approximate, dropout=dropout, layer_norm_epsilon=1e-12, kernel_initializer=albert_kernel_initializer(), + dtype=dtype, name=f"group_{group_idx}_inner_layer_{inner_idx}", ) - for inner_idx in range(num_inner_repetitions) - ] - - def call(x, padding_mask): - for transformer_layer in transformer_layers: - x = transformer_layer(x, padding_mask=padding_mask) - return x - - return call - - num_calls_per_group = num_layers // num_groups - for group_idx in range(num_groups): - # Define the group. A group in ALBERT terminology is any number of - # repeated attention and FFN blocks. - group_layer = get_group_layer(group_idx) - - # Assume num_layers = 8, num_groups = 4. Then, the order of group - # calls will be 0, 0, 1, 1, 2, 2, 3, 3. - for call in range(num_calls_per_group): - x = group_layer(x, padding_mask=padding_mask) - - # Construct the two ALBERT outputs. The pooled output is a dense layer on - # top of the [CLS] token. 
- sequence_output = x - pooled_output = keras.layers.Dense( + inner_layers.append(layer) + self.transformer_layers.append(inner_layers) + self.pooled_dense = keras.layers.Dense( hidden_dim, kernel_initializer=albert_kernel_initializer(), activation="tanh", + dtype=dtype, name="pooled_dense", - )(x[:, cls_token_index, :]) + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + # Inputs + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + segment_id_input = keras.Input( + shape=(None,), dtype="int32", name="segment_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + # Embed tokens, positions, and segment ids. + tokens = self.token_embedding(token_id_input) + positions = self.position_embedding(tokens) + segments = self.segment_embedding(segment_id_input) + # Sum, normalize and apply dropout to embeddings. + x = self.embeddings_add((tokens, positions, segments)) + x = self.embeddings_layer_norm(x) + x = self.embeddings_dropout(x) + x = self.embeddings_projection(x) + # Call transformer layers with repeated groups. + num_calls_per_group = num_layers // num_groups + for group in self.transformer_layers: + for _ in range(num_calls_per_group): + for transformer_layer in group: + x = transformer_layer(x, padding_mask=padding_mask_input) + # Construct the two ALBERT outputs. The pooled output is a dense layer + # on top of the [CLS] token. + sequence_output = x + cls_token_index = 0 + pooled_output = self.pooled_dense(x[:, cls_token_index, :]) super().__init__( inputs={ "token_ids": token_id_input, "segment_ids": segment_id_input, - "padding_mask": padding_mask, + "padding_mask": padding_mask_input, }, outputs={ "sequence_output": sequence_output, @@ -231,7 +232,8 @@ def call(x, padding_mask): }, **kwargs, ) - # All references to `self` below this line + + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -244,7 +246,6 @@ def call(x, padding_mask): self.max_sequence_length = max_sequence_length self.num_segments = num_segments self.cls_token_index = cls_token_index - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/albert/albert_classifier.py b/keras_nlp/models/albert/albert_classifier.py index b0ed7bca7c..32a4e0847d 100644 --- a/keras_nlp/models/albert/albert_classifier.py +++ b/keras_nlp/models/albert/albert_classifier.py @@ -155,30 +155,39 @@ def __init__( dropout=0.1, **kwargs, ): - inputs = backbone.input - pooled = backbone(inputs)["pooled_output"] - pooled = keras.layers.Dropout(dropout)(pooled) - outputs = keras.layers.Dense( + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.output_dense = keras.layers.Dense( num_classes, kernel_initializer=albert_kernel_initializer(), activation=activation, + dtype=backbone.dtype_policy, name="logits", - )(pooled) - # Instantiate using Functional API Model constructor + ) + self.output_dropout = keras.layers.Dropout( + dropout, + dtype=backbone.dtype_policy, + name="output_dropout", + ) + + # === Functional Model === + inputs = backbone.input + pooled = backbone(inputs)["pooled_output"] + pooled = self.output_dropout(pooled) + outputs = self.output_dense(pooled) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line - self._backbone = backbone - 
self._preprocessor = preprocessor + + # === Config === self.num_classes = num_classes self.activation = keras.activations.get(activation) self.dropout = dropout - # Default compilation + # === Default compilation === logit_output = self.activation == keras.activations.linear self.compile( loss=keras.losses.SparseCategoricalCrossentropy( diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index e95af7c207..1958713b9f 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -97,32 +97,36 @@ class AlbertMaskedLM(Task): """ def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.masked_lm_head = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + token_embedding=backbone.token_embedding, + intermediate_activation=gelu_approximate, + kernel_initializer=albert_kernel_initializer(), + dtype=backbone.dtype_policy, + name="mlm_head", + ) + + # === Functional Model === inputs = { **backbone.input, "mask_positions": keras.Input( shape=(None,), dtype="int32", name="mask_positions" ), } - backbone_outputs = backbone(backbone.input) - outputs = MaskedLMHead( - vocabulary_size=backbone.vocabulary_size, - token_embedding=backbone.token_embedding, - intermediate_activation=gelu_approximate, - kernel_initializer=albert_kernel_initializer(), - name="mlm_head", - )(backbone_outputs["sequence_output"], inputs["mask_positions"]) - + outputs = self.masked_lm_head( + backbone_outputs["sequence_output"], inputs["mask_positions"] + ) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, - **kwargs + **kwargs, ) - self.backbone = backbone - self.preprocessor = preprocessor - + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py index 79d3a36bbb..b9bf693c17 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py @@ -43,7 +43,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=AlbertMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/albert/albert_preprocessor.py b/keras_nlp/models/albert/albert_preprocessor.py index 5d5628a729..19f4bd9a7b 100644 --- a/keras_nlp/models/albert/albert_preprocessor.py +++ b/keras_nlp/models/albert/albert_preprocessor.py @@ -158,9 +158,9 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.truncate = truncate self.sequence_length = sequence_length - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -195,6 +195,17 @@ def call(self, x, y=None, sample_weight=None): } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): 
return AlbertTokenizer diff --git a/keras_nlp/models/albert/albert_preprocessor_test.py b/keras_nlp/models/albert/albert_preprocessor_test.py index 7d6fb4cfd4..ad5da8a47b 100644 --- a/keras_nlp/models/albert/albert_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_preprocessor_test.py @@ -40,7 +40,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=AlbertPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 69da56593b..9c8cdaa60e 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset @@ -21,26 +22,42 @@ @keras.saving.register_keras_serializable(package="keras_nlp") class Backbone(keras.Model): - def __init__(self, *args, **kwargs): + def __init__(self, *args, dtype=None, **kwargs): super().__init__(*args, **kwargs) - self._token_embedding = None self._functional_layer_ids = set( id(layer) for layer in self._flatten_layers() ) + self._initialized = True def __dir__(self): - # Temporary fixes for weight saving. This mimics the following PR for + if config.keras_3(): + return super().__dir__() + + # Temporary fixes for Keras 2 saving. This mimics the following PR for # older version of Keras: https://github.com/keras-team/keras/pull/18982 def filter_fn(attr): - if attr == "_layer_checkpoint_dependencies": + if attr in [ + "_layer_checkpoint_dependencies", + "transformer_layers", + "encoder_transformer_layers", + "decoder_transformer_layers", + ]: return False return id(getattr(self, attr)) not in self._functional_layer_ids return filter(filter_fn, super().__dir__()) def __setattr__(self, name, value): - # Work around torch setattr for properties. - if name in ["token_embedding"]: + # Work around setattr issues for Keras 2 and Keras 3 torch backend. + # Since all our state is covered by functional model we can route + # around custom setattr calls. + is_property = isinstance(getattr(type(self), name, None), property) + is_unitialized = not hasattr(self, "_initialized") + is_torch = config.backend() == "torch" + is_keras_2 = not config.keras_3() + if is_torch and (is_property or is_unitialized): + return object.__setattr__(self, name, value) + if is_keras_2 and is_unitialized: return object.__setattr__(self, name, value) return super().__setattr__(name, value) @@ -48,22 +65,17 @@ def __setattr__(self, name, value): def token_embedding(self): """A `keras.layers.Embedding` instance for embedding token ids. - This layer integer token ids to the hidden dim of the model. + This layer embeds integer token ids to the hidden dim of the model. """ return self._token_embedding @token_embedding.setter def token_embedding(self, value): - # Workaround tf.keras h5 checkpoint loading, which is sensitive to layer - # count mismatches and does not deduplicate layers. This could go away - # if we update our checkpoints to the newer `.weights.h5` format. - self._setattr_tracking = False self._token_embedding = value - self._setattr_tracking = True def get_config(self): - # Don't chain to super here. 
The default `get_config()` for functional - # models is nested and cannot be passed to our Backbone constructors. + # Don't chain to super here. `get_config()` for functional models is + # a nested layer config and cannot be passed to Backbone constructors. return { "name": self.name, "trainable": self.trainable, @@ -140,3 +152,80 @@ def from_preset(calling_cls, *args, **kwargs): example_preset_name=next(iter(cls.presets), ""), preset_names='", "'.join(cls.presets), )(cls.from_preset.__func__) + + def enable_lora(self, rank): + """Enable Lora on the backbone. + + Calling this method will freeze all weights on the backbone, + while enabling Lora on the query & value `EinsumDense` layers + of the attention layers. + """ + target_names = ["query_dense", "value_dense", "query", "value"] + self.trainable = True + self._lora_enabled_layers = [] + self._lora_rank = rank + for layer in self._flatten_layers(include_self=False): + layer.trainable = False + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for i, layer in enumerate(all_layers): + for name in target_names: + if layer.name == name: + if hasattr(layer, "enable_lora"): + layer.trainable = True + layer.enable_lora(rank) + self._lora_enabled_layers.append(i) + + def save_lora_weights(self, filepath): + if not getattr(self, "_lora_enabled_layers", []): + raise ValueError( + "There are no lora-enabled layers in this model. " + "Make sure to call `.enable_lora(rank)` first." + ) + if not str(filepath).endswith(".lora.h5"): + raise ValueError( + "The filename must end in `.lora.h5`. " + f"Received: filepath={filepath}" + ) + + store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="w") + lora_store = store.make("lora") + lora_store["rank"] = self._lora_rank + # We cannot identify layers by name since names are non-unique, + # so we identify them by index in the topologically sorted list + # of layers that have weights. + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for layer_index in self._lora_enabled_layers: + # We only lora the einsumdense layers, + # so the factored weights are always named `kernel` + layer = all_layers[layer_index] + inner_store = store.make(f"lora/{layer_index}") + inner_store["lora_kernel_a"] = layer.lora_kernel_a + inner_store["lora_kernel_b"] = layer.lora_kernel_b + store.close() + + def load_lora_weights(self, filepath): + store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="r") + lora_store = store.get("lora") + rank = int(lora_store["rank"][()]) + + if not getattr(self, "_lora_enabled_layers", []): + self.enable_lora(rank) + else: + if self._lora_rank != rank: + raise ValueError( + f"The Lora rank expected by file '{filepath}' " + f"is rank={rank}, but the model was called with " + f"`.enable_lora(rank={self._lora_rank})`. " + "Both ranks must match." 
+ ) + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for layer_index in self._lora_enabled_layers: + layer = all_layers[layer_index] + lora_kernel_a = store.get(f"lora/{layer_index}")["lora_kernel_a"] + lora_kernel_b = store.get(f"lora/{layer_index}")["lora_kernel_b"] + layer.lora_kernel_a.assign(lora_kernel_a) + layer.lora_kernel_b.assign(lora_kernel_b) + store.close() diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py index 2679b84a9f..803d5a2a9f 100644 --- a/keras_nlp/models/bart/bart_backbone.py +++ b/keras_nlp/models/bart/bart_backbone.py @@ -60,6 +60,10 @@ class BartBackbone(Backbone): can consume. If None, `max_sequence_length` uses the value from sequence length. This determines the variable shape for positional embeddings. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Examples: ```python @@ -100,125 +104,129 @@ def __init__( intermediate_dim, dropout=0.1, max_sequence_length=1024, + dtype=None, **kwargs, ): - # Encoder inputs - encoder_token_id_input = keras.Input( - shape=(None,), dtype="int32", name="encoder_token_ids" - ) - encoder_padding_mask = keras.Input( - shape=(None,), dtype="int32", name="encoder_padding_mask" - ) - - # Decoder inputs. - decoder_token_id_input = keras.Input( - shape=(None,), dtype="int32", name="decoder_token_ids" - ) - decoder_padding_mask = keras.Input( - shape=(None,), dtype="int32", name="decoder_padding_mask" - ) - - # Token embedding layer. This layer is shared by encoder and decoder. - token_embedding_layer = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=bart_kernel_initializer(), + dtype=dtype, name="token_embedding", ) - - # ===== Encoder ===== - - # Embed tokens and positions. - token_embedding = token_embedding_layer(encoder_token_id_input) - # Position embedding parameters are not shared by encode and decoder. - position_embedding = PositionEmbedding( + self.encoder_position_embedding = PositionEmbedding( initializer=bart_kernel_initializer(), sequence_length=max_sequence_length, + dtype=dtype, name="encoder_position_embedding", - )(token_embedding) - - # Sum, normalize and apply dropout to embeddings. - x = keras.layers.Add(name="encoder_embeddings_add")( - (token_embedding, position_embedding) ) - x = keras.layers.LayerNormalization( - name="encoder_embeddings_layer_norm", + self.encoder_embeddings_add = keras.layers.Add( + dtype=dtype, + name="encoder_embeddings_add", + ) + self.encoder_embeddings_layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-5, - dtype="float32", - )(x) - x = keras.layers.Dropout( + dtype=dtype, + name="encoder_embeddings_layer_norm", + ) + self.encoder_embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="encoder_embeddings_dropout", - )(x) - - # Apply successive transformer encoder blocks. 
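As a hypothetical usage sketch of the LoRA helpers added to `Backbone` above (the preset name is illustrative, and `enable_lora` relies on the query/value `EinsumDense` sublayers exposing `enable_lora`, i.e. a Keras 3 setup):

```python
import keras_nlp

# Illustrative preset; any backbone whose attention layers expose
# `query`/`value` EinsumDense sublayers with LoRA support should work.
backbone = keras_nlp.models.GPT2Backbone.from_preset("gpt2_base_en")

backbone.enable_lora(rank=4)  # Freezes base weights, enables LoRA on query/value.
# ... fine-tune the model here ...
backbone.save_lora_weights("adapter.lora.h5")  # Filename must end in `.lora.h5`.
backbone.load_lora_weights("adapter.lora.h5")  # Calls `enable_lora` first if needed.
```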
+ ) + self.encoder_transformer_layers = [] for i in range(num_layers): - x = TransformerEncoder( + layer = TransformerEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, activation=keras.activations.gelu, dropout=dropout, layer_norm_epsilon=1e-5, kernel_initializer=bart_kernel_initializer(), + dtype=dtype, name=f"transformer_encoder_layer_{i}", - )(x, padding_mask=encoder_padding_mask) - - encoder_output = x - - # ===== Decoder ===== - - # Embed tokens and positions. - token_embedding = token_embedding_layer(decoder_token_id_input) - # Position embedding parameters are not shared by encode and decoder. - position_embedding = PositionEmbedding( + ) + self.encoder_transformer_layers.append(layer) + self.decoder_position_embedding = PositionEmbedding( initializer=bart_kernel_initializer(), sequence_length=max_sequence_length, + dtype=dtype, name="decoder_position_embedding", - )(token_embedding) - - # Sum, normalize and apply dropout to embeddings. - x = keras.layers.Add(name="decoder_embeddings_add")( - (token_embedding, position_embedding) ) - x = keras.layers.LayerNormalization( - name="decoder_embeddings_layer_norm", + self.decoder_embeddings_add = keras.layers.Add( + dtype=dtype, + name="decoder_embeddings_add", + ) + self.decoder_embeddings_layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-5, - dtype="float32", - )(x) - x = keras.layers.Dropout( + dtype=dtype, + name="decoder_embeddings_layer_norm", + ) + self.decoder_embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="decoder_embeddings_dropout", - )(x) - - # Apply successive transformer decoder blocks. + ) + self.decoder_transformer_layers = [] for i in range(num_layers): - transformer_decoder_layer = TransformerDecoder( + layer = TransformerDecoder( intermediate_dim=intermediate_dim, num_heads=num_heads, dropout=dropout, activation=keras.activations.gelu, layer_norm_epsilon=1e-5, kernel_initializer=bart_kernel_initializer(), + dtype=dtype, name=f"transformer_decoder_layer_{i}", ) - x = transformer_decoder_layer( + self.decoder_transformer_layers.append(layer) + + # === Functional Model === + encoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_token_ids" + ) + encoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_padding_mask" + ) + decoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_token_ids" + ) + decoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_padding_mask" + ) + # Encoder. + tokens = self.token_embedding(encoder_token_id_input) + positions = self.encoder_position_embedding(tokens) + x = self.encoder_embeddings_add((tokens, positions)) + x = self.encoder_embeddings_layer_norm(x) + x = self.encoder_embeddings_dropout(x) + for transformer_layer in self.encoder_transformer_layers: + x = transformer_layer(x, padding_mask=encoder_padding_mask_input) + encoder_output = x + # Decoder. 
+ tokens = self.token_embedding(decoder_token_id_input) + positions = self.decoder_position_embedding(tokens) + x = self.decoder_embeddings_add((tokens, positions)) + x = self.decoder_embeddings_layer_norm(x) + x = self.decoder_embeddings_dropout(x) + for transformer_layer in self.decoder_transformer_layers: + x = transformer_layer( decoder_sequence=x, encoder_sequence=encoder_output, - decoder_padding_mask=decoder_padding_mask, - encoder_padding_mask=encoder_padding_mask, + decoder_padding_mask=decoder_padding_mask_input, + encoder_padding_mask=encoder_padding_mask_input, ) - decoder_output = x - # Instantiate using Functional API Model constructor super().__init__( inputs={ "encoder_token_ids": encoder_token_id_input, - "encoder_padding_mask": encoder_padding_mask, + "encoder_padding_mask": encoder_padding_mask_input, "decoder_token_ids": decoder_token_id_input, - "decoder_padding_mask": decoder_padding_mask, + "decoder_padding_mask": decoder_padding_mask_input, }, outputs={ "encoder_sequence_output": encoder_output, @@ -227,7 +235,7 @@ def __init__( **kwargs, ) - # All references to `self` below this line + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -235,7 +243,6 @@ def __init__( self.intermediate_dim = intermediate_dim self.dropout = dropout self.max_sequence_length = max_sequence_length - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/bart/bart_preprocessor.py b/keras_nlp/models/bart/bart_preprocessor.py index ffe2148839..3310b1e532 100644 --- a/keras_nlp/models/bart/bart_preprocessor.py +++ b/keras_nlp/models/bart/bart_preprocessor.py @@ -140,10 +140,10 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer - self.encoder_sequence_length = encoder_sequence_length - self.decoder_sequence_length = decoder_sequence_length self.encoder_packer = None self.decoder_packer = None + self.encoder_sequence_length = encoder_sequence_length + self.decoder_sequence_length = decoder_sequence_length def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -174,7 +174,17 @@ def build(self, input_shape): ) self.built = True - def call(self, x, y=None, sample_weight=None): + def call( + self, + x, + y=None, + sample_weight=None, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + # `sequence_length` is an alias for `decoder_sequence_length` + sequence_length=None, + ): if not ( isinstance(x, dict) and all(k in x for k in ("encoder_text", "decoder_text")) @@ -184,6 +194,12 @@ def call(self, x, y=None, sample_weight=None): f' and `"decoder_text"`. Received x={x}.' 
) + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + encoder_text = x["encoder_text"] decoder_text = x["decoder_text"] @@ -199,12 +215,14 @@ def call(self, x, y=None, sample_weight=None): encoder_inputs = self.tokenizer(encoder_text[0]) encoder_token_ids, encoder_padding_mask = self.encoder_packer( - encoder_inputs + encoder_inputs, + sequence_length=encoder_sequence_length, ) decoder_inputs = self.tokenizer(decoder_text[0]) decoder_token_ids, decoder_padding_mask = self.decoder_packer( - decoder_inputs + decoder_inputs, + sequence_length=decoder_sequence_length, ) x = { @@ -226,6 +244,37 @@ def get_config(self): ) return config + @property + def encoder_sequence_length(self): + """The padded length of encoder input sequences.""" + return self._encoder_sequence_length + + @encoder_sequence_length.setter + def encoder_sequence_length(self, value): + self._encoder_sequence_length = value + if self.encoder_packer is not None: + self.encoder_packer.sequence_length = value + + @property + def decoder_sequence_length(self): + """The padded length of decoder input sequences.""" + return self._decoder_sequence_length + + @decoder_sequence_length.setter + def decoder_sequence_length(self, value): + self._decoder_sequence_length = value + if self.decoder_packer is not None: + self.decoder_packer.sequence_length = value + + @property + def sequence_length(self): + """Alias for `decoder_sequence_length`.""" + return self.decoder_sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self.decoder_sequence_length = value + @classproperty def tokenizer_cls(cls): return BartTokenizer diff --git a/keras_nlp/models/bart/bart_preprocessor_test.py b/keras_nlp/models/bart/bart_preprocessor_test.py index 23cb7cae79..7872e35efa 100644 --- a/keras_nlp/models/bart/bart_preprocessor_test.py +++ b/keras_nlp/models/bart/bart_preprocessor_test.py @@ -46,7 +46,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BartPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, @@ -60,6 +60,7 @@ def test_preprocessor_basics(self): [1], # Pass through labels. [1.0], # Pass through sample_weights. ), + token_id_key="decoder_token_ids", ) def test_error_multi_segment_input(self): diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index 2131519ce3..c530555b3d 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -185,24 +185,21 @@ def __init__( preprocessor=None, **kwargs, ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === inputs = backbone.input hidden_states = backbone(inputs)["decoder_sequence_output"] outputs = backbone.token_embedding(hidden_states, reverse=True) - - # Instantiate using Functional API Model constructor. 
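The keyword-only, per-call sequence lengths added to `BartPreprocessor.call` above can be exercised roughly as follows (a hypothetical sketch; the preset name and lengths are illustrative):

```python
import keras_nlp

# Illustrative preset and lengths.
preprocessor = keras_nlp.models.BartPreprocessor.from_preset("bart_base_en")
features = preprocessor(
    {
        "encoder_text": ["The quick brown fox jumped."],
        "decoder_text": ["The fast fox."],
    },
    encoder_sequence_length=512,
    # `sequence_length` is accepted as an alias for `decoder_sequence_length`.
    decoder_sequence_length=128,
)
# `features` holds "encoder_token_ids", "encoder_padding_mask",
# "decoder_token_ids" and "decoder_padding_mask".
```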
super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - self.backbone = backbone - self.preprocessor = preprocessor - self.generate_function = None - self._sampler = None - - # Default compilation + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(2e-5), @@ -280,33 +277,28 @@ def call_decoder_with_cache( cross-attention layer. """ # Embedding layers. - token_embedding = self.backbone.get_layer("token_embedding")( - decoder_token_ids + tokens = self.backbone.token_embedding(decoder_token_ids) + positions = self.backbone.decoder_position_embedding( + tokens, + start_index=self_attention_cache_update_index, ) - position_embedding = self.backbone.get_layer( - "decoder_position_embedding" - )(token_embedding, start_index=self_attention_cache_update_index) - # Sum, normalize and apply dropout to embeddings. - x = self.backbone.get_layer("decoder_embeddings_add")( - (token_embedding, position_embedding) - ) - x = self.backbone.get_layer("decoder_embeddings_layer_norm")(x) - x = self.backbone.get_layer("decoder_embeddings_dropout")(x) + x = self.backbone.decoder_embeddings_add((tokens, positions)) + x = self.backbone.decoder_embeddings_layer_norm(x) + x = self.backbone.decoder_embeddings_dropout(x) # Every decoder layer has a separate cache for the self-attention layer # and the cross-attention layer. We update all of them separately. self_attention_caches = [] cross_attention_caches = [] - for i in range(self.backbone.num_layers): + for i, layer in enumerate(self.backbone.decoder_transformer_layers): current_self_attention_cache = self_attention_cache[:, i, ...] current_cross_attention_cache = cross_attention_cache[:, i, ...] - ( x, next_self_attention_cache, next_cross_attention_cache, - ) = self.backbone.get_layer(f"transformer_decoder_layer_{i}")( + ) = layer( decoder_sequence=x, encoder_sequence=encoder_hidden_states, encoder_padding_mask=encoder_padding_mask, @@ -315,7 +307,6 @@ def call_decoder_with_cache( cross_attention_cache=current_cross_attention_cache, cross_attention_cache_update_index=cross_attention_cache_update_index, ) - if self_attention_cache_update_index is not None: self_attention_caches.append(next_self_attention_cache) if cross_attention_cache_update_index is not None: @@ -337,26 +328,13 @@ def call_decoder_with_cache( def call_encoder(self, token_ids, padding_mask): """Does a forward pass on the encoder and returns the encoder output.""" - - # Embedding layers. - token_embedding = self.backbone.get_layer("token_embedding")(token_ids) - position_embedding = self.backbone.get_layer( - "encoder_position_embedding" - )(token_embedding) - - # Sum, normalize and apply dropout to embeddings. - x = self.backbone.get_layer("encoder_embeddings_add")( - (token_embedding, position_embedding) - ) - x = self.backbone.get_layer("encoder_embeddings_layer_norm")(x) - x = self.backbone.get_layer("encoder_embeddings_dropout")(x) - - # Transformer encoder layers. 
- for i in range(self.backbone.num_layers): - x = self.backbone.get_layer(f"transformer_encoder_layer_{i}")( - x, padding_mask=padding_mask - ) - + tokens = self.backbone.token_embedding(token_ids) + positions = self.backbone.encoder_position_embedding(tokens) + x = self.backbone.decoder_embeddings_add((tokens, positions)) + x = self.backbone.encoder_embeddings_layer_norm(x) + x = self.backbone.encoder_embeddings_dropout(x) + for transformer_layer in self.backbone.encoder_transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask) return x def _initialize_cache(self, encoder_token_ids, decoder_token_ids): @@ -501,6 +479,7 @@ def repeat_tensor(x): mask=decoder_padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py index 3d398d29d1..1c72e6e935 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py @@ -124,28 +124,17 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor): ``` """ - def __init__( + def call( self, - tokenizer, - encoder_sequence_length=1024, - decoder_sequence_length=1024, - **kwargs + x, + y=None, + sample_weight=None, + *, + encoder_sequence_length=None, + decoder_sequence_length=None, + # `sequence_length` is an alias for `decoder_sequence_length` + sequence_length=None, ): - # Since we truncate the last token from `decoder_token_ids`, we need to - # forcefully set the `decoder_sequence_length` to one greater than the - # value passed. - super().__init__( - tokenizer=tokenizer, - encoder_sequence_length=encoder_sequence_length, - decoder_sequence_length=decoder_sequence_length + 1, - **kwargs - ) - - # Maintain a private copy of the sequence lengths for config purposes. - self._encoder_sequence_length = encoder_sequence_length - self._decoder_sequence_length = decoder_sequence_length - - def call(self, x, y=None, sample_weight=None): if y is not None or sample_weight is not None: logging.warning( "`BartSeq2SeqLMPreprocessor` infers `y` and `sample_weight` " @@ -154,7 +143,17 @@ def call(self, x, y=None, sample_weight=None): "These values will be ignored." ) - x = super().call(x) + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + + x = super().call( + x, + encoder_sequence_length=encoder_sequence_length, + decoder_sequence_length=decoder_sequence_length + 1, + ) decoder_token_ids = x.pop("decoder_token_ids") decoder_padding_mask = x.pop("decoder_padding_mask") @@ -173,6 +172,10 @@ def call(self, x, y=None, sample_weight=None): def generate_preprocess( self, x, + *, + encoder_sequence_length=None, + # `sequence_length` is an alias for `decoder_sequence_length` + decoder_sequence_length=None, sequence_length=None, ): """Convert encoder and decoder input strings to integer token inputs for generation. @@ -190,10 +193,6 @@ def generate_preprocess( if not self.built: self.build(None) - # If `sequence_length` is not provided, we use the default value. 
- if sequence_length is None: - sequence_length = self._decoder_sequence_length - if isinstance(x, dict): encoder_text = x["encoder_text"] decoder_text = x["decoder_text"] @@ -202,6 +201,12 @@ def generate_preprocess( # Initialize empty prompt for the decoder. decoder_text = tf.fill((tf.shape(encoder_text)[0],), "") + if encoder_sequence_length is None: + encoder_sequence_length = self.encoder_sequence_length + decoder_sequence_length = decoder_sequence_length or sequence_length + if decoder_sequence_length is None: + decoder_sequence_length = self.decoder_sequence_length + # Tokenize and pack the encoder inputs. # TODO: Remove `[0]` once we have shifted to `MultiSegmentPacker`. encoder_text = convert_inputs_to_list_of_tensor_segments(encoder_text)[ @@ -209,7 +214,8 @@ def generate_preprocess( ] encoder_token_ids = self.tokenizer(encoder_text) encoder_token_ids, encoder_padding_mask = self.encoder_packer( - encoder_token_ids + encoder_token_ids, + sequence_length=encoder_sequence_length, ) # Tokenize and pack the decoder inputs. @@ -219,7 +225,7 @@ def generate_preprocess( decoder_token_ids = self.tokenizer(decoder_text) decoder_token_ids, decoder_padding_mask = self.decoder_packer( decoder_token_ids, - sequence_length=sequence_length, + sequence_length=decoder_sequence_length, add_end_value=False, ) @@ -261,16 +267,6 @@ def generate_postprocess( ) return self.tokenizer.detokenize(decoder_token_ids) - def get_config(self): - config = super().get_config() - config.update( - { - "encoder_sequence_length": self._encoder_sequence_length, - "decoder_sequence_length": self._decoder_sequence_length, - } - ) - return config - @classproperty def presets(cls): return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py index 33fbd5fc3a..2f40e69722 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py @@ -45,7 +45,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BartSeq2SeqLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, @@ -59,6 +59,7 @@ def test_preprocessor_basics(self): [[0, 4, 5, 4, 7, 2, 1, 1]], [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]], ), + token_id_key="decoder_token_ids", ) def test_generate_preprocess(self): diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py index 174b0f0e42..2248260da7 100644 --- a/keras_nlp/models/bert/bert_backbone.py +++ b/keras_nlp/models/bert/bert_backbone.py @@ -61,6 +61,10 @@ class BertBackbone(Backbone): embeddings. num_segments: int. The number of types that the 'segment_ids' input can take. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. 
Examples: ```python @@ -97,84 +101,96 @@ def __init__( dropout=0.1, max_sequence_length=512, num_segments=2, + dtype=None, **kwargs, ): - # Index of classification token in the vocabulary - cls_token_index = 0 - # Inputs - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" - ) - segment_id_input = keras.Input( - shape=(None,), dtype="int32", name="segment_ids" - ) - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens, positions, and segment ids. - token_embedding_layer = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=bert_kernel_initializer(), + dtype=dtype, name="token_embedding", ) - token_embedding = token_embedding_layer(token_id_input) - position_embedding = PositionEmbedding( + self.position_embedding = PositionEmbedding( initializer=bert_kernel_initializer(), sequence_length=max_sequence_length, + dtype=dtype, name="position_embedding", - )(token_embedding) - segment_embedding = keras.layers.Embedding( + ) + self.segment_embedding = keras.layers.Embedding( input_dim=num_segments, output_dim=hidden_dim, embeddings_initializer=bert_kernel_initializer(), + dtype=dtype, name="segment_embedding", - )(segment_id_input) - - # Sum, normalize and apply dropout to embeddings. - x = keras.layers.Add()( - (token_embedding, position_embedding, segment_embedding) ) - x = keras.layers.LayerNormalization( - name="embeddings_layer_norm", + self.embeddings_add = keras.layers.Add( + dtype=dtype, + name="embeddings_add", + ) + self.embeddings_layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-12, - dtype="float32", - )(x) - x = keras.layers.Dropout( + dtype=dtype, + name="embeddings_layer_norm", + ) + self.embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="embeddings_dropout", - )(x) - - # Apply successive transformer encoder blocks. + ) + self.transformer_layers = [] for i in range(num_layers): - x = TransformerEncoder( + layer = TransformerEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, activation=gelu_approximate, dropout=dropout, layer_norm_epsilon=1e-12, kernel_initializer=bert_kernel_initializer(), + dtype=dtype, name=f"transformer_layer_{i}", - )(x, padding_mask=padding_mask) - - # Construct the two BERT outputs. The pooled output is a dense layer on - # top of the [CLS] token. - sequence_output = x - pooled_output = keras.layers.Dense( + ) + self.transformer_layers.append(layer) + self.pooled_dense = keras.layers.Dense( hidden_dim, kernel_initializer=bert_kernel_initializer(), activation="tanh", + dtype=dtype, name="pooled_dense", - )(x[:, cls_token_index, :]) + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + segment_id_input = keras.Input( + shape=(None,), dtype="int32", name="segment_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + # Embed tokens, positions, and segment ids. + tokens = self.token_embedding(token_id_input) + positions = self.position_embedding(tokens) + segments = self.segment_embedding(segment_id_input) + # Sum, normalize and apply dropout to embeddings. 
+ x = self.embeddings_add((tokens, positions, segments)) + x = self.embeddings_layer_norm(x) + x = self.embeddings_dropout(x) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask_input) + # Construct the two BERT outputs. The pooled output is a dense layer on + # top of the [CLS] token. + sequence_output = x + cls_token_index = 0 + pooled_output = self.pooled_dense(x[:, cls_token_index, :]) super().__init__( inputs={ "token_ids": token_id_input, "segment_ids": segment_id_input, - "padding_mask": padding_mask, + "padding_mask": padding_mask_input, }, outputs={ "sequence_output": sequence_output, @@ -183,7 +199,7 @@ def __init__( **kwargs, ) - # All references to `self` below this line + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -193,7 +209,6 @@ def __init__( self.max_sequence_length = max_sequence_length self.num_segments = num_segments self.cls_token_index = cls_token_index - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/bert/bert_classifier.py b/keras_nlp/models/bert/bert_classifier.py index 2a9aa548bf..09d2b8810c 100644 --- a/keras_nlp/models/bert/bert_classifier.py +++ b/keras_nlp/models/bert/bert_classifier.py @@ -140,30 +140,39 @@ def __init__( dropout=0.1, **kwargs, ): - inputs = backbone.input - pooled = backbone(inputs)["pooled_output"] - pooled = keras.layers.Dropout(dropout)(pooled) - outputs = keras.layers.Dense( + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.output_dropout = keras.layers.Dropout( + dropout, + dtype=backbone.dtype_policy, + name="classifier_dropout", + ) + self.output_dense = keras.layers.Dense( num_classes, kernel_initializer=bert_kernel_initializer(), activation=activation, + dtype=backbone.dtype_policy, name="logits", - )(pooled) - # Instantiate using Functional API Model constructor + ) + + # === Functional Model === + inputs = backbone.input + pooled = backbone(inputs)["pooled_output"] + pooled = self.output_dropout(pooled) + outputs = self.output_dense(pooled) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor + + # === Config === self.num_classes = num_classes self.activation = keras.activations.get(activation) self.dropout = dropout - # Default compilation + # === Default compilation === logit_output = self.activation == keras.activations.linear self.compile( loss=keras.losses.SparseCategoricalCrossentropy( diff --git a/keras_nlp/models/bert/bert_masked_lm.py b/keras_nlp/models/bert/bert_masked_lm.py index d4c12d1091..17b9669619 100644 --- a/keras_nlp/models/bert/bert_masked_lm.py +++ b/keras_nlp/models/bert/bert_masked_lm.py @@ -101,6 +101,19 @@ def __init__( preprocessor=None, **kwargs, ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.masked_lm_head = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + token_embedding=backbone.token_embedding, + intermediate_activation="gelu", + kernel_initializer=bert_kernel_initializer(), + dtype=backbone.dtype_policy, + name="mlm_head", + ) + + # === Functional Model === inputs = { **backbone.input, "mask_positions": keras.Input( @@ -108,22 +121,16 @@ def __init__( ), } backbone_outputs = backbone(backbone.input) - outputs = MaskedLMHead( - 
vocabulary_size=backbone.vocabulary_size, - token_embedding=backbone.token_embedding, - intermediate_activation="gelu", - kernel_initializer=bert_kernel_initializer(), - name="mlm_head", - )(backbone_outputs["sequence_output"], inputs["mask_positions"]) - - # Instantiate using Functional API Model constructor + outputs = self.masked_lm_head( + backbone_outputs["sequence_output"], inputs["mask_positions"] + ) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line + + # === Default compilation === self.backbone = backbone self.preprocessor = preprocessor self.compile( diff --git a/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py index ff58962215..479d9e879b 100644 --- a/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py @@ -39,7 +39,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BertMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/bert/bert_preprocessor.py b/keras_nlp/models/bert/bert_preprocessor.py index bad38f22a5..02f5a45985 100644 --- a/keras_nlp/models/bert/bert_preprocessor.py +++ b/keras_nlp/models/bert/bert_preprocessor.py @@ -139,9 +139,9 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.truncate = truncate - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -176,6 +176,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return BertTokenizer diff --git a/keras_nlp/models/bert/bert_preprocessor_test.py b/keras_nlp/models/bert/bert_preprocessor_test.py index 6d1e5fee57..c109d1006d 100644 --- a/keras_nlp/models/bert/bert_preprocessor_test.py +++ b/keras_nlp/models/bert/bert_preprocessor_test.py @@ -36,7 +36,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=BertPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/bloom/__init__.py b/keras_nlp/models/bloom/__init__.py new file mode 100644 index 0000000000..ba0c2545e4 --- /dev/null +++ b/keras_nlp/models/bloom/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
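The `sequence_length` property and setter added to the preprocessors above (and the `encoder_sequence_length`/`decoder_sequence_length` pair on the BART preprocessors) exist so that reassigning the attribute also updates a packer that has already been built. A minimal, self-contained sketch of that pattern is below; `DemoPacker` and `DemoPreprocessor` are illustrative stand-ins, not KerasNLP classes.

```python
# A minimal sketch of the `sequence_length` property pattern added above.
# `DemoPacker` and `DemoPreprocessor` are illustrative names only.
class DemoPacker:
    def __init__(self, sequence_length):
        self.sequence_length = sequence_length

    def __call__(self, token_ids):
        # Truncate, then pad with zeros up to `sequence_length`.
        ids = list(token_ids)[: self.sequence_length]
        return ids + [0] * (self.sequence_length - len(ids))


class DemoPreprocessor:
    def __init__(self, sequence_length=8):
        self.packer = None  # Created lazily, mirroring `build()`.
        self.sequence_length = sequence_length

    def build(self):
        self.packer = DemoPacker(self.sequence_length)

    @property
    def sequence_length(self):
        """The padded length of model input sequences."""
        return self._sequence_length

    @sequence_length.setter
    def sequence_length(self, value):
        self._sequence_length = value
        # Keep an already-built packer in sync with the new value.
        if self.packer is not None:
            self.packer.sequence_length = value


preprocessor = DemoPreprocessor(sequence_length=8)
preprocessor.build()
preprocessor.sequence_length = 4  # Propagates straight to the packer.
print(preprocessor.packer([5, 6, 7, 8, 9]))  # [5, 6, 7, 8]
```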
diff --git a/keras_nlp/models/bloom/bloom_attention.py b/keras_nlp/models/bloom/bloom_attention.py new file mode 100644 index 0000000000..e36c6fac62 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_attention.py @@ -0,0 +1,185 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.alibi_bias import AlibiBias +from keras_nlp.utils.keras_utils import clone_initializer + + +class BloomAttention(keras.layers.Layer): + def __init__( + self, + num_heads, + dropout=0.0, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.dropout = dropout + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + + def build(self, inputs_shape): + batch_size, seq_length, hidden_dim = inputs_shape + + self.head_dim = hidden_dim // self.num_heads + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + + self._query_dense = keras.layers.EinsumDense( + equation="btm,mnh->btnh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="query_dense", + ) + self._query_dense.build(inputs_shape) + + self._key_dense = keras.layers.EinsumDense( + equation="bsm,mnh->bsnh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="key_dense", + ) + self._key_dense.build(inputs_shape) + + self._value_dense = keras.layers.EinsumDense( + equation="bsm,mnh->bsnh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="value_dense", + ) + self._value_dense.build(inputs_shape) + + self._alibi_layer = AlibiBias( + dtype=self.dtype_policy, + ) + + self._output_dense = keras.layers.Dense( + hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="output_dense", + ) + self._output_dense.build(inputs_shape) + + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="dropout", + ) + self._softmax = keras.layers.Softmax( + dtype="float32", + name="softmax", + ) + + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + ): + batch_size, seq_length, hidden_dim = ops.shape(hidden_states) + + query = self._query_dense(hidden_states) + key = 
self._key_dense(hidden_states) + value = self._value_dense(hidden_states) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key) + value = ops.slice_update(value_cache, start, value) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + + # query (batch_size, num_heads, query_length, head_dim) + query = ops.transpose(query, [0, 2, 1, 3]) + # value (batch_size, num_heads, kv_length, head_dim) + value = ops.transpose(value, [0, 2, 1, 3]) + # key (batch_size, num_heads, head_dim, kv_length) + key = ops.transpose(key, [0, 2, 3, 1]) + + attention_scores = ( + ops.matmul(query, key) * self.inv_norm_factor + ) # [batch_size, num_heads, query_length, kv_length] + attention_scores = self._alibi_layer(attention_scores) + attention_scores = self._softmax( + attention_scores, ops.expand_dims(attention_mask, 1) + ) + attention_scores = self._dropout_layer(attention_scores) + + attention_output = ops.matmul( + attention_scores, value + ) # [batch_size, num_heads, query_length, head_dim] + + attention_output = ops.transpose( + attention_output, [0, 2, 1, 3] + ) # [batch_size, query_length, num_heads, head_dim] + attention_output = ops.reshape( + attention_output, + [batch_size, seq_length, self.num_heads * self.head_dim], + ) # [batch_size, query_length, hidden_dim] + + attention_output = self._output_dense(attention_output) + attention_output = self._dropout_layer(attention_output) + + if cache is not None: + return attention_output, cache + + return attention_output + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "dropout": self.dropout, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + } + ) + return config diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py new file mode 100644 index 0000000000..5e2251a693 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -0,0 +1,182 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
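During generation, `BloomAttention.call` above keeps the per-layer keys and values in a single fused cache and writes one decoding step at a time at `cache_update_index`. The following numpy sketch mirrors that bookkeeping under assumed shapes; the real layer operates on backend tensors via `ops.slice_update`.

```python
# A numpy sketch of the fused key/value cache update used during decoding.
# Shapes here are assumptions chosen for illustration.
import numpy as np

batch, max_length, num_heads, head_dim = 1, 8, 2, 4

# cache[:, 0] holds keys and cache[:, 1] holds values, each shaped
# (batch, max_length, num_heads, head_dim).
cache = np.zeros((batch, 2, max_length, num_heads, head_dim), dtype="float32")


def update_cache(cache, key, value, cache_update_index):
    # Write the new step's key/value projections at `cache_update_index`,
    # mirroring `ops.slice_update(key_cache, start, key)` in spirit.
    length = key.shape[1]
    end = cache_update_index + length
    cache[:, 0, cache_update_index:end] = key
    cache[:, 1, cache_update_index:end] = value
    return cache


# One decoding step: a single new token's key/value projections.
new_key = np.ones((batch, 1, num_heads, head_dim), dtype="float32")
new_value = np.ones((batch, 1, num_heads, head_dim), dtype="float32")
cache = update_cache(cache, new_key, new_value, cache_update_index=3)
print(cache[0, 0, :, 0, 0])  # Only position 3 has been filled in.
```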
+import copy + +from keras_nlp.backend import keras +from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding +from keras_nlp.models.backbone import Backbone +from keras_nlp.models.bloom.bloom_decoder import BloomDecoder +from keras_nlp.models.bloom.bloom_presets import backbone_presets +from keras_nlp.utils.python_utils import classproperty + + +def _bloom_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras.saving.register_keras_serializable(package="keras_nlp") +class BloomBackbone(Backbone): + """A BLOOM decoder network. + + This network implements a Transformer-based decoder network, BigScience + Large Open-science Open-access Multilingual language model (BLOOM), as described in + ["BLOOM: A 176B-Parameter Open-Access Multilingual Language Model"](https://arxiv.org/pdf/2211.05100.pdf). + + The default constructor gives a fully customizable, randomly initialized + Bloom model with any number of layers, heads, and embedding dimensions. To + load preset architectures and weights, use the `from_preset()` constructor. + + Disclaimer: Pre-trained models are provided on an "as is" basis, without + warranties or conditions of any kind. The underlying model is provided by a + third party and subject to a separate license, available [here](https://huggingface.co/spaces/bigscience/license). + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_heads: int. The number of attention heads for each transformer. + The hidden size must be divisible by the number of attention heads. + hidden_dim: int. The dimensionality of the embeddings and hidden states. + intermediate_dim: int. The output dimension of the first Dense layer in + the MLP network of each transformer. + dropout: float. Dropout probability for the Transformer decoder. + layer_norm_epsilon: float. Epsilon for the layer normalization layers in + the transformer decoder. + max_sequence_length: int. The maximum sequence length that this decoder + can consume. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Examples: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained BLOOM decoder. + model = keras_nlp.models.BloomBackbone.from_preset("bloom_560m_multi") + model(input_data) + + # Randomly initialized BLOOM decoder with a custom config.
+ model = keras_nlp.models.BloomBackbone( + vocabulary_size=10, + num_layers=2, + num_heads=2, + hidden_dim=32, + intermediate_dim=32*4, + dropout=0.0, + layer_norm_epsilon=1e-5, + max_sequence_length=128, + ) + model(input_data) + ``` + + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_heads, + hidden_dim, + intermediate_dim, + dropout=0.0, + layer_norm_epsilon=1e-5, + max_sequence_length=2048, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + embeddings_initializer=_bloom_kernel_initializer(stddev=0.02), + tie_weights=False, + dtype=dtype, + name="token_embedding", + ) + self.embeddings_layer_norm = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="token_embedding_layernorm", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = BloomDecoder( + num_heads=num_heads, + intermediate_dim=intermediate_dim, + dropout=dropout, + layer_norm_epsilon=layer_norm_epsilon, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="final_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + x = self.embeddings_layer_norm(x) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.dropout = dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.max_sequence_length = max_sequence_length + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "dropout": self.dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + "max_sequence_length": self.max_sequence_length, + } + ) + return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bloom/bloom_backbone_test.py b/keras_nlp/models/bloom/bloom_backbone_test.py new file mode 100644 index 0000000000..83732e4945 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_backbone_test.py @@ -0,0 +1,76 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
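`BloomBackbone` defines no position-embedding layer: positional information enters through the `AlibiBias` term added to the attention scores inside each decoder layer. The numpy sketch below illustrates the ALiBi penalty using the paper's slope formula for a power-of-two head count; it is an illustration of the idea, not the library implementation.

```python
# A numpy illustration of the ALiBi idea behind the `AlibiBias` layer: no
# position embeddings, just a per-head linear penalty that grows with the
# query/key distance. This is a sketch, not the KerasNLP layer itself.
import numpy as np


def alibi_slopes(num_heads):
    # m_h = 2 ** (-8 * h / num_heads) for h = 1..num_heads (power-of-two heads).
    start = 2.0 ** (-8.0 / num_heads)
    return np.array([start ** (h + 1) for h in range(num_heads)])


def alibi_bias(num_heads, query_length, key_length):
    # distance[i, j] = j - i; keys further in the past get a more negative bias.
    distance = np.arange(key_length)[None, :] - np.arange(query_length)[:, None]
    slopes = alibi_slopes(num_heads)[:, None, None]
    return slopes * np.minimum(distance, 0)  # (num_heads, q_len, k_len)


# The bias is added to the raw attention scores before the softmax.
scores = np.zeros((4, 5, 5))  # (num_heads, query_length, key_length)
print((scores + alibi_bias(4, 5, 5))[0])
```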
+ +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.tests.test_case import TestCase + + +class BloomBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_heads": 4, + "hidden_dim": 8, + "intermediate_dim": 32, + "max_sequence_length": 10, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=BloomBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 8), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=BloomBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=BloomBackbone, + preset="bloom_560m_multi", + input_data={ + "token_ids": ops.array([[101, 1996, 4248, 102]], dtype="int32"), + "padding_mask": ops.ones((1, 4), dtype="int32"), + }, + expected_output_shape=(1, 4, 1024), + # The forward pass from a preset should be stable! + expected_partial_output=ops.array( + [2.4394186, 1.4131186, -2.7810357, -6.330823, -1.0599766] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in BloomBackbone.presets: + self.run_preset_test( + cls=BloomBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py new file mode 100644 index 0000000000..ceec2b67fb --- /dev/null +++ b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py @@ -0,0 +1,182 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from absl import logging + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.bloom.bloom_preprocessor import BloomPreprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras.saving.register_keras_serializable(package="keras_nlp") +class BloomCausalLMPreprocessor(BloomPreprocessor): + """BLOOM Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_nlp.models.BloomCausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_nlp.models.BloomCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). 
+ + Args: + tokenizer: A `keras_nlp.models.BloomTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.BloomCausalLMPreprocessor.from_preset( + "bloom_560m_multi" + ) + + # Tokenize and pack a single sentence. + sentence = tf.constant("League of legends") + preprocessor(sentence) + # Same output. + preprocessor("League of legends") + + # Tokenize a batch of sentences. + sentences = tf.constant(["Taco tuesday", "Fish taco please!"]) + preprocessor(sentences) + # Same output. + preprocessor(["Taco tuesday", "Fish taco please!"]) + + # Map a dataset to preprocess a single sentence. + features = tf.constant( + [ + "Avatar 2 is amazing!", + "Well, I am not sure.", + ] + ) + labels = tf.constant([1, 0]) + ds = tf.data.Dataset.from_tensor_slices((features, labels)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map a dataset to preprocess unlabled sentences. + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`BloomCausalLMPreprocessor` generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Covert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). 
+ """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + token_ids = ops.convert_to_numpy(token_ids) + padding_mask = ops.convert_to_numpy(padding_mask) + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. + padding_mask = ( + padding_mask + & (token_ids != self.tokenizer.eos_token_id) + & (token_ids != self.tokenizer.bos_token_id) + ) + token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor_test.py b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..a281519340 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor_test.py @@ -0,0 +1,94 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import ( + BloomCausalLMPreprocessor, +) +from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.tests.test_case import TestCase + + +class BloomCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["<pad>", "<s>", "</s>"] + self.vocab += ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = BloomTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["airplane at airport"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=BloomCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 4, 6, 7, 5, 8, 2, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + [[4, 6, 7, 5, 8, 2, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights.
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["airplane at airport"] * 4 + + preprocessor = BloomCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 6, 7, 5, 8, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[6, 7, 5, 8, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "airplane at airport" + preprocessor = BloomCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 4, 6, 7, 5, 8, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 4, 6, 7, 5, 8, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0], + } + preprocessor = BloomCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "airplane at airport") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in BloomCausalLMPreprocessor.presets: + self.run_preset_test( + cls=BloomCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py new file mode 100644 index 0000000000..a0c62a2541 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_decoder.py @@ -0,0 +1,209 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from keras_nlp.backend import keras +# from keras_nlp.backend import ops +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.bloom.bloom_attention import BloomAttention +from keras_nlp.utils.keras_utils import clone_initializer + + +class BloomDecoder(keras.layers.Layer): + def __init__( + self, + num_heads, + intermediate_dim, + dropout=0.0, + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + **kwargs, + ): + super().__init__(**kwargs) + + self.num_heads = num_heads + self.intermediate_dim = intermediate_dim + self.dropout = dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + + def build(self, decoder_sequence_shape): + hidden_dim = decoder_sequence_shape[-1] + head_dim = int(hidden_dim // self.num_heads) + + if head_dim * self.num_heads != hidden_dim: + raise ValueError( + f"`hidden_dim` must be divisible by num_heads (got `hidden_dim`" + f": {hidden_dim} and `num_heads`: {self.num_heads})." 
+ ) + + self._pre_attention_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_attention_layernorm", + ) + self._pre_attention_layernorm.build(decoder_sequence_shape) + + self._self_attention_layer = BloomAttention( + num_heads=self.num_heads, + dropout=self.dropout, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="self_attention", + ) + self._self_attention_layer.build(decoder_sequence_shape) + + self._post_attention_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="post_attention_layernorm", + ) + self._post_attention_layernorm.build(decoder_sequence_shape) + + self._mlp_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="mlp_intermediate_dense", + ) + self._mlp_intermediate_dense.build(decoder_sequence_shape) + + self._mlp_output_dense = keras.layers.Dense( + hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="mlp_output_dense", + ) + intermediate_shape = list(decoder_sequence_shape) + intermediate_shape[-1] = self.intermediate_dim + self._mlp_output_dense.build(tuple(intermediate_shape)) + + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, dtype=self.dtype_policy, name="dropout" + ) + + self.built = True + + def call( + self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + attention_cache=None, + attention_cache_update_index=None, + use_causal_mask=True, + ): + self_attention_mask = self._compute_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + use_causal_mask=use_causal_mask, + attention_cache=attention_cache, + attention_cache_update_index=attention_cache_update_index, + ) + + residual = decoder_sequence + x = self._pre_attention_layernorm(decoder_sequence) + + attention_output = self._self_attention_layer( + hidden_states=x, + attention_mask=self_attention_mask, + cache=attention_cache, + cache_update_index=attention_cache_update_index, + ) + + if attention_cache is None: + x = attention_output + else: + x, attention_cache = attention_output + + x = x + residual + residual = x + x = self._post_attention_layernorm(x) + x = self._mlp_intermediate_dense(x) + x = keras.activations.gelu(x, approximate=True) + x = self._mlp_output_dense(x) + x = self._dropout_layer(x) + x = x + residual + + if attention_cache is not None: + return x, attention_cache + else: + return x + + def _compute_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + use_causal_mask, + attention_cache, + attention_cache_update_index, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + if use_causal_mask: + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + if attention_cache is not None: + input_length = ops.shape(attention_cache)[2] + + causal_mask = compute_causal_mask( + batch_size, + input_length, + output_length, + ( + 0 + if attention_cache_update_index is None 
+ else attention_cache_update_index + ), + ) + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + return decoder_mask + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "intermediate_dim": self.intermediate_dim, + "dropout": self.dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + } + ) + return config + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape diff --git a/keras_nlp/models/bloom/bloom_preprocessor.py b/keras_nlp/models/bloom/bloom_preprocessor.py new file mode 100644 index 0000000000..45003916ba --- /dev/null +++ b/keras_nlp/models/bloom/bloom_preprocessor.py @@ -0,0 +1,193 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_nlp.backend import keras +from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.models.bloom.bloom_presets import backbone_presets +from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.models.preprocessor import Preprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight +from keras_nlp.utils.python_utils import classproperty + + +@keras.saving.register_keras_serializable(package="keras_nlp") +class BloomPreprocessor(Preprocessor): + """BLOOM preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do 2 things: + + - Tokenize the inputs using the `tokenizer`. + - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can + be passed directly to a `keras_nlp.models.BloomBackbone`. + + This layer can be used directly with `tf.data.Dataset.map` to preprocess + string data in the `(x, y, sample_weight)` format used by + `keras.Model.fit`. + + The call method of this layer accepts three arguments, `x`, `y`, and + `sample_weight`. `x` can be a python string or tensor representing a single + segment, a list of python strings representing a batch of single segments, + or a list of tensors representing multiple segments to be packed together. + `y` and `sample_weight` are both optional, can have any format, and will be + passed through unaltered. + + Args: + tokenizer: A `keras_nlp.models.BloomTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Any label data. Will be passed through unaltered. + sample_weight: Any label weight data. 
Will be passed through unaltered. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + + Directly calling the layer on data. + ```python + preprocessor = keras_nlp.models.BloomPreprocessor.from_preset( + "bloom_560m_multi" + ) + # Tokenize and pack a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of single sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Custom vocabulary. + features = ["a quick fox.", "a fox quick."] + vocab = {"<pad>": 0, "<s>": 1, "</s>": 2, "a": 3, "Ġquick": 4, "Ġfox": 5} + merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"] + merges += ["Ġ f", "o x", "Ġf ox"] + tokenizer = keras_nlp.models.BloomTokenizer( + vocabulary=vocab, + merges=merges, + ) + preprocessor = keras_nlp.models.BloomPreprocessor(tokenizer=tokenizer) + preprocessor("The quick brown fox jumped.") + ``` + + Mapping with `tf.data.Dataset`. + ```python + preprocessor = keras_nlp.models.BloomPreprocessor.from_preset( + "bloom_560m_multi" + ) + + text = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + label = tf.constant([1, 1]) + + # Map labeled single sentences. + ds = tf.data.Dataset.from_tensor_slices((text, label)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map unlabeled single sentences. + ds = tf.data.Dataset.from_tensor_slices(text) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def __init__( + self, + tokenizer, + sequence_length=2048, + add_start_token=True, + add_end_token=True, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.packer = None + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = StartEndPacker( + start_value=self.tokenizer.bos_token_id, + end_value=self.tokenizer.eos_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "BLOOM requires each input feature to contain only " + f"one segment, but received {len(x)}. If you are using BLOOM " + "for a multi-segment classification task, please refer to " + "classification models like BERT or RoBERTa."
+ ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @classproperty + def tokenizer_cls(cls): + return BloomTokenizer diff --git a/keras_nlp/models/bloom/bloom_preprocessor_test.py b/keras_nlp/models/bloom/bloom_preprocessor_test.py new file mode 100644 index 0000000000..938113ef4b --- /dev/null +++ b/keras_nlp/models/bloom/bloom_preprocessor_test.py @@ -0,0 +1,80 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
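`BloomCausalLMPreprocessor.call` shown earlier packs each sequence one token longer than requested and then shifts it to build features, labels, and sample weights. A pure-Python sketch of that shift, using the same token ids as the unit test above:

```python
# Next-token setup: pack to sequence_length + 1, then offset by one position.
token_ids    = [1, 4, 6, 7, 5, 8, 2, 0, 0]  # packed to sequence_length + 1
padding_mask = [1, 1, 1, 1, 1, 1, 1, 0, 0]

x = {
    "token_ids": token_ids[:-1],  # model input: every position but the last
    "padding_mask": padding_mask[:-1],
}
y = token_ids[1:]                 # label: the next token at each position
sample_weight = padding_mask[1:]  # padded positions contribute no loss

print(x["token_ids"])  # [1, 4, 6, 7, 5, 8, 2, 0]
print(y)               # [4, 6, 7, 5, 8, 2, 0, 0]
```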
+ +import pytest + +from keras_nlp.models.bloom.bloom_preprocessor import BloomPreprocessor +from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.tests.test_case import TestCase + + +class BloomPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["<pad>", "<s>", "</s>"] + self.vocab += ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = BloomTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["airplane at airport"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=BloomPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[1, 4, 6, 7, 5, 8, 2, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = ["airplane at airport"] * 4 + + preprocessor = BloomPreprocessor( + tokenizer=BloomTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ), + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 6, 7, 5, 8, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "airplane at airport" + preprocessor = BloomPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [1, 4, 6, 2]) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in BloomPreprocessor.presets: + self.run_preset_test( + cls=BloomPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/bloom/bloom_presets.py b/keras_nlp/models/bloom/bloom_presets.py new file mode 100644 index 0000000000..d3e9c780c0 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_presets.py @@ -0,0 +1,30 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLOOM model preset configurations.""" + +backbone_presets = { + "bloom_560m_multi": { + "metadata": { + "description": ( + "24-layer Bloom model trained on 45 natural languages and " + "12 programming languages."
+ ), + "params": 816115712, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/1", + }, +} diff --git a/keras_nlp/models/bloom/bloom_tokenizer.py b/keras_nlp/models/bloom/bloom_tokenizer.py new file mode 100644 index 0000000000..6ab26c6353 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_tokenizer.py @@ -0,0 +1,123 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_nlp.backend import keras +from keras_nlp.models.bloom.bloom_presets import backbone_presets +from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer +from keras_nlp.utils.python_utils import classproperty + + +@keras.saving.register_keras_serializable(package="keras_nlp") +class BloomTokenizer(BytePairTokenizer): + """A BLOOM tokenizer using Byte-Pair Encoding subword segmentation. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.BytePairTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by BLOOM + models and provides a `from_preset()` method to automatically download + a matching vocabulary for a BLOOM preset. + + This tokenizer does not provide truncation or padding of inputs. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + vocabulary: string or dict, maps token to integer ids. If it is a + string, it should be the file path to a json file. + merges: string or list, contains the merge rule. If it is a string, + it should be the file path to merge rules. The merge rule file + should have one merge rule per line. Every merge rule contains + merge entities separated by a space. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_nlp.models.BloomTokenizer.from_preset("bloom_560m_multi") + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. 
+ vocab = {"<s>": 0, "</s>": 1, "<pad>": 2, "a": 3, "Ġquick": 4, "Ġfox": 5} + merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"] + merges += ["Ġ f", "o x", "Ġf ox"] + tokenizer = keras_nlp.models.BloomTokenizer(vocabulary=vocab, merges=merges) + tokenizer("a quick fox.") + ``` + """ + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + self.bos_token = "<s>" + self.eos_token = "</s>" + self.pad_token = "<pad>" + + super().__init__( + vocabulary=vocabulary, + merges=merges, + unsplittable_tokens=[ + self.bos_token, + self.eos_token, + self.pad_token, + ], + **kwargs, + ) + + def set_vocabulary_and_merges(self, vocabulary, merges): + super().set_vocabulary_and_merges(vocabulary, merges) + + if vocabulary is not None: + # Check for necessary special tokens. + for token in [self.bos_token, self.eos_token, self.pad_token]: + if token not in self.get_vocabulary(): + raise ValueError( + f"Cannot find token `'{token}'` in the provided " + f"`vocabulary`. Please provide `'{token}'` in " + "your `vocabulary` or use a pretrained `vocabulary` name." + ) + + self.bos_token_id = self.token_to_id(self.bos_token) + self.eos_token_id = self.token_to_id(self.eos_token) + self.pad_token_id = self.token_to_id(self.pad_token) + else: + self.bos_token_id = None + self.eos_token_id = None + self.pad_token_id = None + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + def get_config(self): + config = super().get_config() + # In the constructor, we pass the list of special tokens to the + # `unsplittable_tokens` arg of the superclass' constructor. Hence, we + # delete it from the config here. + del config["unsplittable_tokens"] + return config diff --git a/keras_nlp/models/bloom/bloom_tokenizer_test.py b/keras_nlp/models/bloom/bloom_tokenizer_test.py new file mode 100644 index 0000000000..9ae9c0cc00 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_tokenizer_test.py @@ -0,0 +1,63 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
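As a small illustration (not part of the diff) of the constructor and `get_config()` contract above, reusing the toy vocabulary from the docstring example:

```python
from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer

vocab = {"<s>": 0, "</s>": 1, "<pad>": 2, "a": 3, "Ġquick": 4, "Ġfox": 5}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]

tokenizer = BloomTokenizer(vocabulary=vocab, merges=merges)

# The three special tokens were found in the vocabulary, so their ids are set.
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1
assert tokenizer.pad_token_id == 2

# `unsplittable_tokens` is re-created in `__init__`, so it is dropped from the
# serialized config, as the `get_config` override above explains.
assert "unsplittable_tokens" not in tokenizer.get_config()

# A vocabulary missing any of `<s>`, `</s>`, `<pad>` raises a `ValueError`.
```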
+ +import pytest + +from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.tests.test_case import TestCase + + +class BloomTokenizerTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<s>", "</s>", "<pad>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.input_data = [ + "<s>airplane at airport<pad>", + "<s> airplane airport<pad>", + ] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=BloomTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[6, 1, 3, 4, 2, 5, 8], [6, 2, 3, 2, 5, 8]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + BloomTokenizer(vocabulary=["a", "b", "c"], merges=[]) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=BloomTokenizer, + preset="bloom_560m_multi", + input_data=["The quick brown fox."], + expected_output=[[2175, 23714, 73173, 144252, 17]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in BloomTokenizer.presets: + self.run_preset_test( + cls=BloomTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py index aa5077ec67..e7bd8ca20a 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py @@ -67,6 +67,10 @@ class DebertaV3Backbone(Backbone): `max_sequence_length`. bucket_size: int. The size of the relative position buckets. Generally equal to `max_sequence_length // 2`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Example: ```python @@ -106,48 +110,38 @@ def __init__( dropout=0.1, max_sequence_length=512, bucket_size=256, + dtype=None, **kwargs, ): - # Inputs - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" - ) - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens. - token_embedding_layer = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=deberta_kernel_initializer(), + dtype=dtype, name="token_embedding", ) - x = token_embedding_layer(token_id_input) - - # Normalize and apply dropout to embeddings. - x = keras.layers.LayerNormalization( + self.embeddings_layer_norm = keras.layers.LayerNormalization( epsilon=1e-7, - dtype="float32", + dtype=dtype, name="embeddings_layer_norm", - )(x) - x = keras.layers.Dropout( + ) + self.embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="embeddings_dropout", - )(x) - - # Relative embedding layer. - rel_embeddings = RelativeEmbedding( + ) + self.relative_embeddings = RelativeEmbedding( hidden_dim=hidden_dim, bucket_size=bucket_size, layer_norm_epsilon=1e-7, kernel_initializer=deberta_kernel_initializer(), + dtype=dtype, name="rel_embedding", - )(x) - - # Apply successive DeBERTa encoder blocks.
+ ) + self.transformer_layers = [] for i in range(num_layers): - x = DisentangledAttentionEncoder( + layer = DisentangledAttentionEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, max_position_embeddings=max_sequence_length, @@ -156,23 +150,38 @@ def __init__( activation=keras.activations.gelu, layer_norm_epsilon=1e-7, kernel_initializer=deberta_kernel_initializer(), + dtype=dtype, name=f"disentangled_attention_encoder_layer_{i}", - )( + ) + self.transformer_layers.append(layer) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + x = self.embeddings_layer_norm(x) + x = self.embeddings_dropout(x) + rel_embeddings = self.relative_embeddings(x) + for transformer_layer in self.transformer_layers: + x = transformer_layer( x, rel_embeddings=rel_embeddings, - padding_mask=padding_mask, + padding_mask=padding_mask_input, ) - - # Instantiate using Functional API Model constructor super().__init__( inputs={ "token_ids": token_id_input, - "padding_mask": padding_mask, + "padding_mask": padding_mask_input, }, outputs=x, **kwargs, ) - # All references to `self` below this line + + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -182,7 +191,6 @@ def __init__( self.max_sequence_length = max_sequence_length self.bucket_size = bucket_size self.start_token_index = 0 - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py index b03122064d..d6eea63601 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py @@ -163,32 +163,48 @@ def __init__( dropout=0.0, **kwargs, ): - inputs = backbone.input + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.pooled_dropout = keras.layers.Dropout( + dropout, + dtype=backbone.dtype_policy, + name="pooled_dropout", + ) hidden_dim = hidden_dim or backbone.hidden_dim - - x = backbone(inputs)[:, backbone.start_token_index, :] - x = keras.layers.Dropout(dropout, name="pooled_dropout")(x) - x = keras.layers.Dense( + self.pooled_dense = keras.layers.Dense( hidden_dim, activation=keras.activations.gelu, + dtype=backbone.dtype_policy, name="pooled_dense", - )(x) - x = keras.layers.Dropout(backbone.dropout, name="classifier_dropout")(x) - outputs = keras.layers.Dense( + ) + self.output_dropout = keras.layers.Dropout( + backbone.dropout, + dtype=backbone.dtype_policy, + name="classifier_dropout", + ) + self.output_dense = keras.layers.Dense( num_classes, kernel_initializer=deberta_kernel_initializer(), activation=activation, + dtype=backbone.dtype_policy, name="logits", - )(x) + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + inputs = backbone.input + x = backbone(inputs)[:, backbone.start_token_index, :] + x = self.pooled_dropout(x) + x = self.pooled_dense(x) + x = self.output_dropout(x) + outputs = self.output_dense(x) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line + + # === Config === self.backbone = backbone self.preprocessor = preprocessor self.num_classes = num_classes @@ -196,7 +212,7 @@ def 
__init__( self.hidden_dim = hidden_dim self.dropout = dropout - # Default compilation + # === Default compilation === logit_output = self.activation == keras.activations.linear self.compile( loss=keras.losses.SparseCategoricalCrossentropy( diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py index bf6a850a54..d050dde6c0 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py @@ -104,32 +104,34 @@ def __init__( preprocessor=None, **kwargs, ): - inputs = { - **backbone.input, - "mask_positions": keras.Input( - shape=(None,), dtype="int32", name="mask_positions" - ), - } - backbone_outputs = backbone(backbone.input) - outputs = MaskedLMHead( + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.masked_lm_head = MaskedLMHead( vocabulary_size=backbone.vocabulary_size, token_embedding=backbone.token_embedding, intermediate_activation=keras.activations.gelu, kernel_initializer=deberta_kernel_initializer(), + dtype=backbone.dtype_policy, name="mlm_head", - )(backbone_outputs, inputs["mask_positions"]) + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + inputs = { + **backbone.input, + "mask_positions": keras.Input( + shape=(None,), dtype="int32", name="mask_positions" + ), + } + x = backbone(backbone.input) + outputs = self.masked_lm_head(x, inputs["mask_positions"]) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py index 217980ea59..f041a6f7ff 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py @@ -43,7 +43,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=DebertaV3MaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py index 93f4fbbd22..88fa08fd70 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py @@ -156,9 +156,9 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.truncate = truncate self.sequence_length = sequence_length - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -192,6 +192,17 @@ def call(self, x, y=None, sample_weight=None): } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return DebertaV3Tokenizer diff --git 
a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py index a50022f3c7..a9e2a59c29 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py @@ -42,7 +42,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=DebertaV3Preprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py b/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py index 081c345c79..be79bc98ff 100644 --- a/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py +++ b/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py @@ -99,22 +99,26 @@ def build(self, inputs_shape): dropout=self.dropout, kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, name="self_attention_layer", ) self._self_attention_layer.build(inputs_shape) self._self_attention_layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, name="self_attention_layer_norm", ) self._self_attention_layer_norm.build(inputs_shape) self._self_attention_dropout = keras.layers.Dropout( rate=self.dropout, + dtype=self.dtype_policy, name="self_attention_dropout", ) # Feedforward layers. self._feedforward_layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, name="feedforward_layer_norm", ) self._feedforward_layer_norm.build(inputs_shape) @@ -123,6 +127,7 @@ def build(self, inputs_shape): activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, name="feedforward_intermediate_dense", ) self._feedforward_intermediate_dense.build(inputs_shape) @@ -130,6 +135,7 @@ def build(self, inputs_shape): hidden_dim, kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, name="feedforward_output_dense", ) intermediate_shape = list(inputs_shape) @@ -137,6 +143,7 @@ def build(self, inputs_shape): self._feedforward_output_dense.build(tuple(intermediate_shape)) self._feedforward_dropout = keras.layers.Dropout( rate=self.dropout, + dtype=self.dtype_policy, name="feedforward_dropout", ) self.built = True diff --git a/keras_nlp/models/deberta_v3/disentangled_self_attention.py b/keras_nlp/models/deberta_v3/disentangled_self_attention.py index 1c9ae569c7..48730f2ad6 100644 --- a/keras_nlp/models/deberta_v3/disentangled_self_attention.py +++ b/keras_nlp/models/deberta_v3/disentangled_self_attention.py @@ -86,6 +86,7 @@ def build(self, inputs_shape, rel_embeddings_shape=None): output_shape=(None, self.num_heads, self.attn_head_size), bias_axes="de", **self._get_common_kwargs_for_sublayer(use_bias=True), + dtype=self.dtype_policy, name="query", ) self._query_dense.build(inputs_shape) @@ -94,6 +95,7 @@ def build(self, inputs_shape, rel_embeddings_shape=None): output_shape=(None, self.num_heads, self.attn_head_size), bias_axes="de", **self._get_common_kwargs_for_sublayer(use_bias=True), + dtype=self.dtype_policy, name="key", ) self._key_dense.build(inputs_shape) @@ -102,17 +104,27 @@ def build(self, inputs_shape, rel_embeddings_shape=None): output_shape=(None, 
self.num_heads, self.attn_head_size), bias_axes="de", **self._get_common_kwargs_for_sublayer(use_bias=True), + dtype=self.dtype_policy, name="value", ) self._value_dense.build(inputs_shape) # Relative attention. - self._position_dropout_layer = keras.layers.Dropout(self.dropout) + self._position_dropout_layer = keras.layers.Dropout( + self.dropout, + dtype=self.dtype_policy, + ) self._attn_dropout_layer = keras.layers.Dropout( - self.dropout, name="attention_dropout" + self.dropout, + dtype=self.dtype_policy, + name="attention_dropout", + ) + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", ) - self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax") # Output. self._output_dense = keras.layers.EinsumDense( @@ -120,6 +132,7 @@ def build(self, inputs_shape, rel_embeddings_shape=None): output_shape=(None, self.hidden_dim), bias_axes="d", **self._get_common_kwargs_for_sublayer(use_bias=True), + dtype=self.dtype_policy, name="attention_output", ) self._output_dense.build(inputs_shape) diff --git a/keras_nlp/models/deberta_v3/relative_embedding.py b/keras_nlp/models/deberta_v3/relative_embedding.py index 6ae29a5fd7..f727ce0568 100644 --- a/keras_nlp/models/deberta_v3/relative_embedding.py +++ b/keras_nlp/models/deberta_v3/relative_embedding.py @@ -57,7 +57,9 @@ def __init__( name="rel_embedding", ) self.layer_norm = keras.layers.LayerNormalization( - epsilon=layer_norm_epsilon, name="rel_embeddings_layer_norm" + epsilon=layer_norm_epsilon, + dtype=self.dtype_policy, + name="rel_embeddings_layer_norm", ) def call(self, inputs): diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py index a3634215fa..1ae0840ea8 100644 --- a/keras_nlp/models/distil_bert/distil_bert_backbone.py +++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py @@ -62,6 +62,10 @@ class DistilBertBackbone(Backbone): can consume. If None, `max_sequence_length` uses the value from sequence length. This determines the variable shape for positional embeddings. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Examples: ```python @@ -98,60 +102,67 @@ def __init__( intermediate_dim, dropout=0.1, max_sequence_length=512, + dtype=None, **kwargs, ): - # Inputs - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" - ) - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens and positions. - embedding_layer = TokenAndPositionEmbedding( + # === Layers === + self.embeddings = TokenAndPositionEmbedding( vocabulary_size=vocabulary_size, sequence_length=max_sequence_length, embedding_dim=hidden_dim, embeddings_initializer=distilbert_kernel_initializer(), + dtype=dtype, name="token_and_position_embedding", ) - x = embedding_layer(token_id_input) - - # Normalize and apply dropout to embeddings. - x = keras.layers.LayerNormalization( + # Keep the token_embedding property for consistency across models. 
+ self.token_embedding = self.embeddings.token_embedding + self.embeddings_layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-12, - dtype="float32", + dtype=dtype, name="embeddings_layer_norm", - )(x) - x = keras.layers.Dropout( + ) + self.embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="embeddings_dropout", - )(x) - - # Apply successive transformer encoder blocks. + ) + self.transformer_layers = [] for i in range(num_layers): - x = TransformerEncoder( + layer = TransformerEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, activation="gelu", dropout=dropout, layer_norm_epsilon=1e-12, kernel_initializer=distilbert_kernel_initializer(), + dtype=dtype, name=f"transformer_layer_{i}", - )(x, padding_mask=padding_mask) + ) + self.transformer_layers.append(layer) - # Instantiate using Functional API Model constructor + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.embeddings(token_id_input) + x = self.embeddings_layer_norm(x) + x = self.embeddings_dropout(x) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask_input) super().__init__( inputs={ "token_ids": token_id_input, - "padding_mask": padding_mask, + "padding_mask": padding_mask_input, }, outputs=x, **kwargs, ) - # All references to `self` below this line + + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -160,7 +171,6 @@ def __init__( self.dropout = dropout self.max_sequence_length = max_sequence_length self.cls_token_index = 0 - self.token_embedding = embedding_layer.token_embedding def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier.py b/keras_nlp/models/distil_bert/distil_bert_classifier.py index 42de1cee83..e82aaf2781 100644 --- a/keras_nlp/models/distil_bert/distil_bert_classifier.py +++ b/keras_nlp/models/distil_bert/distil_bert_classifier.py @@ -150,39 +150,49 @@ def __init__( dropout=0.2, **kwargs, ): - inputs = backbone.input + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor hidden_dim = hidden_dim or backbone.hidden_dim - - x = backbone(inputs)[:, backbone.cls_token_index, :] - x = keras.layers.Dense( + self.pooled_dense = keras.layers.Dense( hidden_dim, activation="relu", kernel_initializer=distilbert_kernel_initializer(), + dtype=backbone.dtype_policy, name="pooled_dense", - )(x) - x = keras.layers.Dropout(dropout, name="classifier_dropout")(x) - outputs = keras.layers.Dense( + ) + self.output_dropout = keras.layers.Dropout( + dropout, + dtype=backbone.dtype_policy, + name="output_dropout", + ) + self.output_dense = keras.layers.Dense( num_classes, kernel_initializer=distilbert_kernel_initializer(), activation=activation, + dtype=backbone.dtype_policy, name="logits", - )(x) + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + inputs = backbone.input + x = backbone(inputs)[:, backbone.cls_token_index, :] + x = self.pooled_dense(x) + x = self.output_dropout(x) + outputs = self.output_dense(x) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor + + # === Config === self.num_classes = num_classes 
self.activation = keras.activations.get(activation) self.hidden_dim = hidden_dim self.dropout = dropout + # === Default compilation === logit_output = self.activation == keras.activations.linear self.compile( loss=keras.losses.SparseCategoricalCrossentropy( diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py index 71cb117d5b..fcf54e014d 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py @@ -104,6 +104,19 @@ def __init__( preprocessor=None, **kwargs, ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.masked_lm_head = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + token_embedding=backbone.token_embedding, + intermediate_activation="gelu", + kernel_initializer=distilbert_kernel_initializer(), + dtype=backbone.dtype_policy, + name="mlm_head", + ) + + # === Functional Model === inputs = { **backbone.input, "mask_positions": keras.Input( @@ -111,25 +124,16 @@ def __init__( ), } backbone_outputs = backbone(backbone.input) - outputs = MaskedLMHead( - vocabulary_size=backbone.vocabulary_size, - token_embedding=backbone.token_embedding, - intermediate_activation="gelu", - kernel_initializer=distilbert_kernel_initializer(), - name="mlm_head", - )(backbone_outputs, inputs["mask_positions"]) - - # Instantiate using Functional API Model constructor + outputs = self.masked_lm_head( + backbone_outputs, inputs["mask_positions"] + ) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py index b01b1da8ac..85ee5bdd43 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py @@ -41,7 +41,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=DistilBertMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py index 107275f80a..63f4e3637b 100644 --- a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py +++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py @@ -127,6 +127,7 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.truncate = truncate @@ -162,6 +163,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return DistilBertTokenizer diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py 
index 22d69c88dc..f58b42cd39 100644 --- a/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py @@ -40,7 +40,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=DistilBertPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py index 66d1db8ccc..0cc0358e82 100644 --- a/keras_nlp/models/electra/electra_backbone.py +++ b/keras_nlp/models/electra/electra_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding @@ -25,7 +24,7 @@ def electra_kernel_initializer(stddev=0.02): return keras.initializers.TruncatedNormal(stddev=stddev) -@keras_nlp_export("keras_nlp.models.ElectraBackbone") +@keras.saving.register_keras_serializable(package="keras_nlp") class ElectraBackbone(Backbone): """A Electra encoder network. @@ -58,6 +57,10 @@ class ElectraBackbone(Backbone): can consume. If None, `max_sequence_length` uses the value from sequence length. This determines the variable shape for positional embeddings. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Examples: ```python @@ -92,92 +95,107 @@ def __init__( dropout=0.1, max_sequence_length=512, num_segments=2, + dtype=None, **kwargs, ): - # Index of classification token in the vocabulary - cls_token_index = 0 - # Inputs - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" - ) - segment_id_input = keras.Input( - shape=(None,), dtype="int32", name="segment_ids" - ) - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens, positions, and segment ids. - token_embedding_layer = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocab_size, output_dim=embedding_dim, embeddings_initializer=electra_kernel_initializer(), + dtype=dtype, name="token_embedding", ) - token_embedding = token_embedding_layer(token_id_input) - position_embedding = PositionEmbedding( + self.position_embedding = PositionEmbedding( initializer=electra_kernel_initializer(), sequence_length=max_sequence_length, + dtype=dtype, name="position_embedding", - )(token_embedding) - segment_embedding = keras.layers.Embedding( + ) + self.segment_embedding = keras.layers.Embedding( input_dim=num_segments, output_dim=embedding_dim, embeddings_initializer=electra_kernel_initializer(), + dtype=dtype, name="segment_embedding", - )(segment_id_input) - - # Add all embeddings together. 
- x = keras.layers.Add()( - (token_embedding, position_embedding, segment_embedding), ) - # Layer normalization - x = keras.layers.LayerNormalization( - name="embeddings_layer_norm", + self.embeddings_add = keras.layers.Add( + dtype=dtype, + name="embeddings_add", + ) + self.embeddings_layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-12, - dtype="float32", - )(x) - # Dropout - x = keras.layers.Dropout( + dtype=dtype, + name="embeddings_layer_norm", + ) + self.embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="embeddings_dropout", - )(x) + ) if hidden_dim != embedding_dim: - x = keras.layers.Dense( + self.embeddings_projection = keras.layers.Dense( hidden_dim, kernel_initializer=electra_kernel_initializer(), + dtype=dtype, name="embeddings_projection", - )(x) - - # Apply successive transformer encoder blocks. + ) + self.transformer_layers = [] for i in range(num_layers): - x = TransformerEncoder( + layer = TransformerEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, activation=gelu_approximate, dropout=dropout, layer_norm_epsilon=1e-12, kernel_initializer=electra_kernel_initializer(), + dtype=dtype, name=f"transformer_layer_{i}", - )(x, padding_mask=padding_mask) - - sequence_output = x - # Construct the two ELECTRA outputs. The pooled output is a dense layer on - # top of the [CLS] token. - pooled_output = keras.layers.Dense( + ) + self.transformer_layers.append(layer) + self.pooled_dense = keras.layers.Dense( hidden_dim, kernel_initializer=electra_kernel_initializer(), activation="tanh", + dtype=dtype, name="pooled_dense", - )(x[:, cls_token_index, :]) + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + segment_id_input = keras.Input( + shape=(None,), dtype="int32", name="segment_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + # Embed tokens, positions, and segment ids. + tokens = self.token_embedding(token_id_input) + positions = self.position_embedding(tokens) + segments = self.segment_embedding(segment_id_input) + # Add all embeddings together. + x = self.embeddings_add((tokens, positions, segments)) + x = self.embeddings_layer_norm(x) + x = self.embeddings_dropout(x) + if hidden_dim != embedding_dim: + x = self.embeddings_projection(x) + # Apply successive transformer encoder blocks. + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask_input) + # Index of classification token in the vocabulary + cls_token_index = 0 + sequence_output = x + # Construct the two ELECTRA outputs. The pooled output is a dense layer on + # top of the [CLS] token. 
+ pooled_output = self.pooled_dense(x[:, cls_token_index, :]) super().__init__( inputs={ "token_ids": token_id_input, "segment_ids": segment_id_input, - "padding_mask": padding_mask, + "padding_mask": padding_mask_input, }, outputs={ "sequence_output": sequence_output, @@ -186,7 +204,7 @@ def __init__( **kwargs, ) - # All references to self below this line + # === Config === self.vocab_size = vocab_size self.num_layers = num_layers self.num_heads = num_heads @@ -197,7 +215,6 @@ def __init__( self.max_sequence_length = max_sequence_length self.num_segments = num_segments self.cls_token_index = cls_token_index - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/electra/electra_tokenizer.py b/keras_nlp/models/electra/electra_tokenizer.py index acd665c2a3..4fb7829424 100644 --- a/keras_nlp/models/electra/electra_tokenizer.py +++ b/keras_nlp/models/electra/electra_tokenizer.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras from keras_nlp.tokenizers import WordPieceTokenizer -@keras_nlp_export("keras_nlp.models.ElectraTokenizer") +@keras.saving.register_keras_serializable(package="keras_nlp") class ElectraTokenizer(WordPieceTokenizer): """A ELECTRA tokenizer using WordPiece subword segmentation. diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py index ac4d290b02..309f312a17 100644 --- a/keras_nlp/models/f_net/f_net_backbone.py +++ b/keras_nlp/models/f_net/f_net_backbone.py @@ -66,6 +66,10 @@ class FNetBackbone(Backbone): embeddings. num_segments: int. The number of types that the 'segment_ids' input can take. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Examples: ```python @@ -99,83 +103,99 @@ def __init__( dropout=0.1, max_sequence_length=512, num_segments=4, + dtype=None, **kwargs, ): - # Index of classification token in the vocabulary - cls_token_index = 0 - # Inputs - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" - ) - segment_id_input = keras.Input( - shape=(None,), dtype="int32", name="segment_ids" - ) - - # Embed tokens, positions, and segment ids. - token_embedding_layer = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=f_net_kernel_initializer(), + dtype=dtype, name="token_embedding", ) - token_embedding = token_embedding_layer(token_id_input) - position_embedding = PositionEmbedding( + self.position_embedding = PositionEmbedding( initializer=f_net_kernel_initializer(), sequence_length=max_sequence_length, + dtype=dtype, name="position_embedding", - )(token_embedding) - segment_embedding = keras.layers.Embedding( + ) + self.segment_embedding = keras.layers.Embedding( input_dim=num_segments, output_dim=hidden_dim, embeddings_initializer=f_net_kernel_initializer(), + dtype=dtype, name="segment_embedding", - )(segment_id_input) - - # Sum, normalize and apply dropout to embeddings. 
- x = keras.layers.Add()( - (token_embedding, position_embedding, segment_embedding) ) - x = keras.layers.LayerNormalization( - name="embeddings_layer_norm", + self.embeddings_add = keras.layers.Add( + dtype=dtype, + name="embeddings_add", + ) + self.embeddings_layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-12, - dtype="float32", - )(x) - - x = keras.layers.Dense( + dtype=dtype, + name="embeddings_layer_norm", + ) + self.embedding_projection = keras.layers.Dense( hidden_dim, kernel_initializer=f_net_kernel_initializer(), bias_initializer=f_net_bias_initializer(), + dtype=dtype, name="embedding_projection", - )(x) - x = keras.layers.Dropout( + ) + self.embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="embeddings_dropout", - )(x) - - # Apply successive FNet encoder blocks. + ) + self.transformer_layers = [] for i in range(num_layers): - x = FNetEncoder( + layer = FNetEncoder( intermediate_dim=intermediate_dim, activation=gelu_approximate, dropout=dropout, layer_norm_epsilon=1e-12, kernel_initializer=f_net_kernel_initializer(), bias_initializer=f_net_bias_initializer(), + dtype=dtype, name=f"f_net_layer_{i}", - )(x) - - # Construct the two FNet outputs. The pooled output is a dense layer on - # top of the [CLS] token. - sequence_output = x - pooled_output = keras.layers.Dense( + ) + self.transformer_layers.append(layer) + self.pooled_dense = keras.layers.Dense( hidden_dim, kernel_initializer=f_net_kernel_initializer(), bias_initializer=f_net_bias_initializer(), activation="tanh", + dtype=dtype, name="pooled_dense", - )(x[:, cls_token_index, :]) + ) + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + segment_id_input = keras.Input( + shape=(None,), dtype="int32", name="segment_ids" + ) + # Embed tokens, positions, and segment ids. + tokens = self.token_embedding(token_id_input) + positions = self.position_embedding(tokens) + segments = self.segment_embedding(segment_id_input) + # Sum, normalize and apply dropout to embeddings. + x = self.embeddings_add((tokens, positions, segments)) + x = self.embeddings_layer_norm(x) + x = self.embedding_projection(x) + x = self.embeddings_dropout(x) + # Apply successive FNet encoder blocks. + for transformer_layer in self.transformer_layers: + x = transformer_layer(x) + # Index of classification token in the vocabulary + cls_token_index = 0 + # Construct the two FNet outputs. The pooled output is a dense layer on + # top of the [CLS] token. 
+ sequence_output = x + pooled_output = self.pooled_dense(x[:, cls_token_index, :]) # Instantiate using Functional API Model constructor super().__init__( inputs={ @@ -189,7 +209,7 @@ def __init__( **kwargs, ) - # All references to `self` below this line + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.hidden_dim = hidden_dim @@ -198,7 +218,6 @@ def __init__( self.max_sequence_length = max_sequence_length self.num_segments = num_segments self.cls_token_index = cls_token_index - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py index f6485485e1..512182d2cd 100644 --- a/keras_nlp/models/f_net/f_net_classifier.py +++ b/keras_nlp/models/f_net/f_net_classifier.py @@ -109,29 +109,39 @@ def __init__( dropout=0.1, **kwargs, ): - inputs = backbone.input - pooled = backbone(inputs)["pooled_output"] - pooled = keras.layers.Dropout(dropout)(pooled) - outputs = keras.layers.Dense( + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.output_dropout = keras.layers.Dropout( + dropout, + dtype=backbone.dtype_policy, + name="output_dropout", + ) + self.output_dense = keras.layers.Dense( num_classes, kernel_initializer=f_net_kernel_initializer(), activation=activation, + dtype=backbone.dtype_policy, name="logits", - )(pooled) - # Instantiate using Functional API Model constructor + ) + + # === Functional Model === + inputs = backbone.input + pooled = backbone(inputs)["pooled_output"] + pooled = self.output_dropout(pooled) + outputs = self.output_dense(pooled) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor + + # === Config === self.num_classes = num_classes self.activation = keras.activations.get(activation) self.dropout = dropout + # === Default compilation === logit_output = self.activation == keras.activations.linear self.compile( loss=keras.losses.SparseCategoricalCrossentropy( diff --git a/keras_nlp/models/f_net/f_net_masked_lm.py b/keras_nlp/models/f_net/f_net_masked_lm.py index d7048cd525..c715a70843 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm.py +++ b/keras_nlp/models/f_net/f_net_masked_lm.py @@ -101,6 +101,19 @@ def __init__( preprocessor=None, **kwargs, ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.masked_lm_head = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + token_embedding=backbone.token_embedding, + intermediate_activation="gelu", + kernel_initializer=f_net_kernel_initializer(), + dtype=backbone.dtype_policy, + name="mlm_head", + ) + + # === Functional Model === inputs = { **backbone.input, "mask_positions": keras.Input( @@ -108,24 +121,16 @@ def __init__( ), } backbone_outputs = backbone(backbone.input) - outputs = MaskedLMHead( - vocabulary_size=backbone.vocabulary_size, - token_embedding=backbone.token_embedding, - intermediate_activation="gelu", - kernel_initializer=f_net_kernel_initializer(), - name="mlm_head", - )(backbone_outputs["sequence_output"], inputs["mask_positions"]) - - # Instantiate using Functional API Model constructor + outputs = self.masked_lm_head( + backbone_outputs["sequence_output"], inputs["mask_positions"] + ) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, 
**kwargs, ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor + + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), diff --git a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py index 5f72081a0d..7d2ecc0f17 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py @@ -41,7 +41,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=FNetMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/f_net/f_net_preprocessor.py b/keras_nlp/models/f_net/f_net_preprocessor.py index 296493c930..b4cb5836bb 100644 --- a/keras_nlp/models/f_net/f_net_preprocessor.py +++ b/keras_nlp/models/f_net/f_net_preprocessor.py @@ -129,9 +129,9 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.truncate = truncate self.sequence_length = sequence_length - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -165,6 +165,17 @@ def call(self, x, y=None, sample_weight=None): } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return FNetTokenizer diff --git a/keras_nlp/models/f_net/f_net_preprocessor_test.py b/keras_nlp/models/f_net/f_net_preprocessor_test.py index f67737c828..c9096ac59f 100644 --- a/keras_nlp/models/f_net/f_net_preprocessor_test.py +++ b/keras_nlp/models/f_net/f_net_preprocessor_test.py @@ -38,7 +38,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=FNetPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/gemma/__init__.py b/keras_nlp/models/gemma/__init__.py new file mode 100644 index 0000000000..ba0c2545e4 --- /dev/null +++ b/keras_nlp/models/gemma/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py new file mode 100644 index 0000000000..80c2ac6a63 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_attention.py @@ -0,0 +1,197 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.utils.keras_utils import clone_initializer + + +class CachedGemmaAttention(keras.layers.Layer): + """A cached grouped query attention layer.""" + + def __init__( + self, + head_dim, + num_query_heads, + num_key_value_heads, + kernel_initializer="glorot_uniform", + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.dropout = dropout + + self._kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + self.num_key_value_groups = num_query_heads // num_key_value_heads + + def build(self, inputs_shape): + self.hidden_dim = inputs_shape[-1] + + self.query_dense = keras.layers.EinsumDense( + "btd,ndh->btnh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(inputs_shape) + + self.key_dense = keras.layers.EinsumDense( + "bsd,kdh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(inputs_shape) + + self.value_dense = keras.layers.EinsumDense( + "bsd,kdh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(inputs_shape) + + self.dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + self.output_dense = keras.layers.EinsumDense( + equation="btnh,nhd->btd", + output_shape=(None, self.hidden_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self.output_dense.build( + (None, None, self.num_query_heads, self.head_dim) + ) + self.softmax = keras.layers.Softmax(dtype="float32") + self.built = True + + def _apply_rope(self, x, positions): + """Rope rotate q or k.""" + # TODO: refactor to use RotaryEmbedding layer? + max_wavelength = 10000 + x_shape = ops.shape(x) + freq_exponents = (2.0 / x_shape[-1]) * ops.cast( + ops.arange(x_shape[-1] // 2, dtype="float32"), self.compute_dtype + ) + timescale = max_wavelength**freq_exponents + radians = positions[..., None] / timescale[None, None, :] + radians = radians[..., None, :] + sin, cos = ops.sin(radians), ops.cos(radians) + x1, x2 = ops.split(x, 2, axis=-1) + # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA + # compilation on jax. 
We should be able to remove this once the + # following PR is in all jax releases we care about: + # https://github.com/openxla/xla/pull/7875 + output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + return ops.reshape(output, x_shape) + + def _compute_attention( + self, + q, + k, + v, + attention_mask, + training=False, + ): + query_normalization = 1 / np.sqrt(self.head_dim) + + q *= ops.cast(query_normalization, dtype=q.dtype) + q_shape = ops.shape(q) + q = ops.reshape( + q, + ( + *q_shape[:-2], + self.num_key_value_heads, + self.num_query_heads // self.num_key_value_heads, + q_shape[-1], + ), + ) + b, q_len, _, _, h = ops.shape(q) + + attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k) + attention_mask = attention_mask[:, None, None, :, :] + orig_dtype = attention_logits.dtype + attention_softmax = self.softmax(attention_logits, mask=attention_mask) + attention_softmax = ops.cast(attention_softmax, orig_dtype) + + if self.dropout: + attention_softmax = self.dropout_layer( + attention_softmax, training=training + ) + + results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v) + return ops.reshape(results, (b, q_len, self.num_query_heads, h)) + + def call( + self, + x, + attention_mask=None, + cache=None, + cache_update_index=0, + training=False, + ): + seq_len = ops.shape(x)[1] + start_index = cache_update_index + positions = ops.cast( + ops.arange(seq_len, dtype="float32"), self.compute_dtype + ) + positions = positions + ops.cast(start_index, self.compute_dtype) + query = self.query_dense(x) + query = self._apply_rope(query, positions) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + key_update = self.key_dense(x) + key_update = self._apply_rope(key_update, positions) + value_update = self.value_dense(x) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + key = self.key_dense(x) + key = self._apply_rope(key, positions) + value = self.value_dense(x) + + attention_vec = self._compute_attention( + query, key, value, attention_mask, training=training + ) + + # Wipe attn vec if there are no attended tokens. + no_attended_tokens = ops.all( + ops.equal(attention_mask, 0), axis=-1, keepdims=True + )[..., None] + attention_vec = ops.where( + no_attended_tokens, ops.zeros_like(attention_vec), attention_vec + ) + + attention_output = self.output_dense(attention_vec) + + if cache is not None: + return attention_output, cache + return attention_output diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py new file mode 100644 index 0000000000..e5814940aa --- /dev/null +++ b/keras_nlp/models/gemma/gemma_backbone.py @@ -0,0 +1,267 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
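To make the cache handling in `CachedGemmaAttention.call` above easier to follow, here is a small NumPy sketch of the assumed cache layout; the sizes are illustrative only and not taken from any preset.

```python
import numpy as np

# Per decoder block, keys and values share one tensor, stacked on axis 1:
# (batch, 2, max_seq_len, num_key_value_heads, head_dim).
batch, max_len, num_kv_heads, head_dim = 2, 8, 1, 4
cache = np.zeros((batch, 2, max_len, num_kv_heads, head_dim), dtype="float32")

key_cache = cache[:, 0, ...]    # what `call` reads back as keys
value_cache = cache[:, 1, ...]  # and as values

# Decoding the token at position t writes one new row, mirroring
# ops.slice_update(key_cache, [0, t, 0, 0], key_update) in the layer.
t = 5
key_update = np.ones((batch, 1, num_kv_heads, head_dim), dtype="float32")
key_cache[:, t : t + 1, ...] = key_update
```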
+ +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import config +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding +from keras_nlp.models.backbone import Backbone +from keras_nlp.models.gemma.gemma_decoder_block import GemmaDecoderBlock +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.models.gemma.rms_normalization import RMSNormalization +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.GemmaBackbone") +class GemmaBackbone(Backbone): + """Gemma core network with hyperparameters. + + This backbone implements the base Transformer network for the Gemma model. + It includes the embedding lookups and transformer layers. This backbone + will output the final hidden states for each token, not generative + predictions over the vocabulary space. For a higher-level object for text + generation, see `keras_nlp.models.GemmaCausalLM`. + + The default constructor gives a fully customizable, randomly initialized + Gemma model with any number of layers, heads, and embedding dimensions. To + load preset architectures and weights, use the `from_preset` constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_query_heads: int. The number of heads for the query projections in + the attention layer. + num_key_value_heads: int. The number of heads for the key and value + projections in the attention layer. + hidden_dim: int. The size of the transformer hidden state at the end + of each transformer layer. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + head_dim: int. The size of each attention head. + layer_norm_epsilon: float. The epsilon value used for every layer norm + in the transformer model. + dropout: float. Dropout probability for the Transformer encoder. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the model's computations and weights. Note that some + computations, such as softmax and layer normalization, will always + be done at float32 precision regardless of dtype. + + Example usage: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Gemma decoder. + model = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en") + model(input_data) + + # Randomly initialized Gemma decoder with custom config. + model = keras_nlp.models.GemmaBackbone( + vocabulary_size=50257, + num_layers=12, + num_query_heads=12, + num_key_value_heads=1, + hidden_dim=768, + intermediate_dim=3072, + head_dim=64, + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + head_dim, + layer_norm_epsilon=1e-6, + dropout=0, + dtype=None, + **kwargs, + ): + if not config.keras_3(): + raise ValueError( + "`GemmaBackbone` requires Keras 3. Run `pip install -U keras` to " + "upgrade your Keras version, or see https://keras.io/getting_started/ " + "for more info on Keras versions and installation."
+ ) + + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=True, + embeddings_initializer=keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="untruncated_normal", + seed=None, + ), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = GemmaDecoderBlock( + intermediate_dim=intermediate_dim, + hidden_dim=hidden_dim, + num_query_heads=num_query_heads, + head_dim=head_dim, + num_key_value_heads=num_key_value_heads, + dropout=dropout, + dtype=dtype, + name=f"decoder_block_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = RMSNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="final_normalization", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="float32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="float32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + x = x * ops.cast(ops.sqrt(hidden_dim), x.dtype) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "head_dim": self.head_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @staticmethod + def get_layout_map(device_mesh, model_parallel_dim_name="model"): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the gemma + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Sample usage: + ``` + # Feel free to change the mesh shape to balance data and model parallel + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), axis_names=('batch', 'model'), + devices=keras.distribution.list_devices()) + layout_map = GemmaBackbone.get_layout_map( + mesh, model_parallel_dim_name="model") + + distribution = keras.distribution.ModelParallel( + mesh, layout_map, batch_dim_name='batch') + with distribution.scope(): + gemma_model = keras_nlp.models.GemmaCausalLM.from_preset() + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. + Return: + `keras.distribution.LayoutMap` that contains the sharding spec + of all the model weights. 
+ """ + # The weight path and shape of the Gemma backbone is like below (for 2G) + # token_embedding/embeddings, (256128, 2048), 524550144 + # repeat block for decoder + # ... + # decoder_block_17/pre_attention_norm/scale, (2048,), 2048 + # decoder_block_17/attention/query/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/key/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/value/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048), 4194304 + # decoder_block_17/pre_ffw_norm/scale, (2048,), 2048 + # decoder_block_17/ffw_gating/kernel, (2048, 16384), 33554432 + # decoder_block_17/ffw_gating_2/kernel, (2048, 16384), 33554432 + # decoder_block_17/ffw_linear/kernel, (16384, 2048), 33554432 + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected `keras.distribution.Device`," + f" got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + model_dim = model_parallel_dim_name + # The sharding is partition for the hidden_dim of the model. + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (None, model_dim) + layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ( + None, + model_dim, + None, + ) + layout_map["decoder_block.*attention_output.*kernel"] = ( + None, + None, + model_dim, + ) + layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None) + layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim) + + return layout_map diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py new file mode 100644 index 0000000000..c66d318fd5 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_backbone_test.py @@ -0,0 +1,128 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 256128, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 4, + "hidden_dim": 128, + "intermediate_dim": 256, + "head_dim": 128, + "layer_norm_epsilon": 1e-6, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 128), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=GemmaBackbone, + preset="gemma_2b_en", + input_data={ + "token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]), + "padding_mask": ops.ones((1, 5), dtype="int32"), + }, + expected_output_shape=(1, 5, 2048), + # The forward pass from a preset should be stable! + expected_partial_output=ops.array( + [1.073359, 0.262374, 0.170238, 0.605402, 2.336161] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaBackbone.presets: + self.run_preset_test( + cls=GemmaBackbone, + preset=preset, + input_data=self.input_data, + ) + + def test_architecture_characteristics(self): + model = GemmaBackbone(**self.init_kwargs) + self.assertEqual(model.count_params(), 33407616) + self.assertEqual(len(model.layers), 6) + + def test_distribution(self): + if keras.backend.backend() != "jax": + return + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + # Need more than 1 device for distribution testing. 
+            return
+        device_mesh = keras.distribution.DeviceMesh(
+            shape=(1, len(devices)),
+            axis_names=("batch", "model"),
+            devices=devices,
+        )
+
+        layout_map = GemmaBackbone.get_layout_map(device_mesh)
+        distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
+        with distribution.scope():
+            model = GemmaBackbone(**self.init_kwargs)
+
+        for w in model.weights:
+            if "token_embedding/embeddings" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+            if "attention/query/kernel" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, "model", None)
+                )
+            if "attention/key/kernel" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, "model", None)
+                )
+            if "attention/value/kernel" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, "model", None)
+                )
+            if "attention/attention_output/kernel" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, None, "model")
+                )
+            if "ffw_gating/kernel" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+            if "ffw_gating_2/kernel" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+            if "ffw_linear/kernel" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py
new file mode 100644
index 0000000000..45c7c6abe0
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_causal_lm.py
@@ -0,0 +1,441 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import keras
+from keras_nlp.backend import ops
+from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone
+from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import (
+    GemmaCausalLMPreprocessor,
+)
+from keras_nlp.models.gemma.gemma_presets import backbone_presets
+from keras_nlp.models.generative_task import GenerativeTask
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.GemmaCausalLM")
+class GemmaCausalLM(GenerativeTask):
+    """An end-to-end Gemma model for causal language modeling.
+
+    A causal language model (LM) predicts the next token based on previous
+    tokens. This task setup can be used to train the model unsupervised on
+    plain text input, or to autoregressively generate plain text similar to
+    the data used for training. This task can be used for pre-training or
+    fine-tuning a Gemma model, simply by calling `fit()`.
+
+    This model has a `generate()` method, which generates text based on a
+    prompt. The generation strategy used is controlled by an additional
+    `sampler` argument on `compile()`. You can recompile the model with
+    different `keras_nlp.samplers` objects to control the generation. By
+    default, `"greedy"` sampling will be used.
+ + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default + when creating the model with `from_preset()`. + + Args: + backbone: A `keras_nlp.models.GemmaBackbone` instance. + preprocessor: A `keras_nlp.models.GemmaCausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + + Examples: + + Use `generate()` to do text generation. + ```python + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en") + gemma_lm.generate("I want to say", max_length=30) + + # Generate with batched prompts. + gemma_lm.generate(["This is a", "Where are you"], max_length=30) + ``` + + Compile the `generate()` function with a custom sampler. + ```python + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en") + gemma_lm.compile(sampler="top_k") + gemma_lm.generate("I want to say", max_length=30) + + gemma_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2)) + gemma_lm.generate("I want to say", max_length=30) + ``` + + Use `generate()` without preprocessing. + ```python + prompt = { + # Token ids for " Keras is". + "token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2), + # Use `"padding_mask"` to indicate values that should not be overridden. + "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2), + } + + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset( + "gemma_2b_en", + preprocessor=None, + ) + gemma_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en") + gemma_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. + ```python + x = { + # Token ids for " Keras is deep learning library" + "token_ids": np.array([[2, 214064, 603, 5271, 6044, 9581, 1, 0]] * 2), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]] * 2), + } + y = np.array([[214064, 603, 5271, 6044, 9581, 3, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2) + + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset( + "gemma_2b_en", + preprocessor=None, + ) + gemma_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + + Custom backbone and vocabulary. 
+    ```python
+    tokenizer = keras_nlp.models.GemmaTokenizer(
+        proto="proto.spm",
+    )
+    preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor(
+        tokenizer=tokenizer,
+        sequence_length=128,
+    )
+    backbone = keras_nlp.models.GemmaBackbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_query_heads=4,
+        num_key_value_heads=1,
+        hidden_dim=256,
+        intermediate_dim=512,
+        head_dim=64,
+    )
+    gemma_lm = keras_nlp.models.GemmaCausalLM(
+        backbone=backbone,
+        preprocessor=preprocessor,
+    )
+    gemma_lm.fit(x=features, batch_size=2)
+    ```
+    """
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        hidden_states = backbone(inputs)
+        outputs = backbone.token_embedding(hidden_states, reverse=True)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Default compilation ===
+        self.compile(
+            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+            optimizer=keras.optimizers.Adam(2e-5),
+            metrics=[keras.metrics.SparseCategoricalAccuracy()],
+            sampler="greedy",
+            jit_compile=True,
+        )
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
+
+    @classproperty
+    def backbone_cls(cls):
+        return GemmaBackbone
+
+    @classproperty
+    def preprocessor_cls(cls):
+        return GemmaCausalLMPreprocessor
+
+    def call_with_cache(
+        self,
+        token_ids,
+        cache,
+        cache_update_index,
+    ):
+        """Forward pass of `GemmaCausalLM` with cache.
+
+        `call_with_cache` adds an additional forward pass for the model for
+        autoregressive inference. Unlike calling the model directly, this
+        method allows caching previous key/value Tensors in the multi-head
+        attention layers, and avoids recomputing the outputs of seen tokens.
+
+        Args:
+            token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+            cache: a dense float Tensor, the cache of key and value.
+            cache_update_index: int, or int Tensor. The index of the current
+                inputs in the whole sequence.
+
+        Returns:
+            A `(logits, hidden_states, cache)` tuple, where `logits` is the
+            language model logits for the input token_ids, `hidden_states` is
+            the final hidden representation of the input tokens, and `cache` is
+            the decoding cache.
+        """
+        x = self.backbone.token_embedding(token_ids)
+        x = x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
+        # Each decoder layer has a cache; we update them separately.
+        caches = []
+        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+            current_cache = cache[:, i, ...]
+            x, next_cache = transformer_layer(
+                x,
+                cache=current_cache,
+                cache_update_index=cache_update_index,
+            )
+            caches.append(next_cache)
+        cache = ops.stack(caches, axis=1)
+        hidden_states = x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+        return logits, hidden_states, cache
+
+    def _build_cache(self, token_ids):
+        """Build an empty cache for use with `call_with_cache()`."""
+        batch_size = ops.shape(token_ids)[0]
+        max_length = ops.shape(token_ids)[1]
+        num_layers = self.backbone.num_layers
+        num_heads = self.backbone.num_key_value_heads
+        head_dim = self.backbone.head_dim
+        shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
+        cache = ops.zeros(shape, dtype=self.compute_dtype)
+        # Seed the cache.
+ _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + end_token_id=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + end_token_id: The id of the end token to stop on. If all + sequences have produced a new `end_token_id`, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self._sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + end_token_id=end_token_id, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if end_token_id is not None: + # Build a mask of `end_token_id` locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = ops.logical_and( + ops.equal(token_ids, end_token_id), + ops.logical_not(padding_mask), + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + """Score a generation represented by the provided token ids. + + Args: + token_ids: A [batch_size, num_tokens] tensor containing tokens + to score. Typically, this tensor captures the output from a call + to `GemmaCausalLM.generate()`, i.e., tokens for both the input + text and the model-generated text. + padding_mask: A [batch_size, num_tokens] tensor indicating the + tokens that should be preserved during generation. This is an + artifact required by the GemmaBackbone and isn't influential on + the computation of this function. If omitted, this function uses + `keras.ops.ones()` to create a tensor of the appropriate shape. + scoring_mode: The type of scores to return, either "logits" or + "loss", both will be per input token. 
+ layer_intercept_fn: An optional function for augmenting activations + with additional computation, for example, as part of + interpretability research. This function will be passed the + activations as its first parameter and a numeric index + associated with that backbone layer. _This index _is not_ an + index into `self.backbone.layers`_. The index -1 accompanies the + embeddings returned by calling `self.backbone.token_embedding()` + on `token_ids` in the forward direction. All subsequent indexes + will be 0-based indices for the activations returned by each of + the Transformers layers in the backbone. This function must + return a [batch_size, num_tokens, hidden_dims] tensor + that can be passed as an input to the next layer in the model. + target_ids: An [batch_size, num_tokens] tensor containing the + predicted tokens against which the loss should be computed. If a + span of tokens is provided (sequential truthy values along + axis=1 in the tensor), the loss will be computed as the + aggregate across those tokens. + + Raises: + ValueError: If an unsupported scoring_mode is provided, or if the + target_ids are not provided when using ScoringMode.LOSS. + + Returns: + The per-token scores as a tensor of size + [batch_size, num_tokens, vocab_size] in "logits" mode, or + [batch_size, num_tokens] in "loss" mode. + + Examples: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset( + "gemma_2b_en" + ) + generations = gemma_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = gemma_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = gemma_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." 
+ ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + x = token_embeddings * ops.cast( + ops.sqrt(self.backbone.hidden_dim), dtype=self.compute_dtype + ) + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py new file mode 100644 index 0000000000..20c66edff3 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py @@ -0,0 +1,173 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from absl import logging + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import ops +from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.GemmaCausalLMPreprocessor") +class GemmaCausalLMPreprocessor(GemmaPreprocessor): + """Gemma Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_nlp.models.GemmaCausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_nlp.models.GemmaCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_nlp.models.GemmaTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. 
Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset( + "gemma_2b_en" + ) + + # Tokenize and pack a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Apply tokenization to a `tf.data.Dataset`. + features = tf.constant(["The quick brown fox.", "Call me Ishmael."]) + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Prepare tokens for generation (no end token). + preprocessor.generate_preprocess(["The quick brown fox jumped."]) + + # Map generation outputs back to strings. + preprocessor.generate_postprocess({ + 'token_ids': np.array([[2, 714, 4320, 8426, 25341, 32292, 235265, 0]]), + 'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]), + }) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`GemmaCausalLMPreprocessor` generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Covert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Covert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. 
+ """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + token_ids = ops.convert_to_numpy(token_ids) + mask = ops.convert_to_numpy(padding_mask) + # Also strip any special tokens during detokenization (e.g. the start + # and end markers). In the future we could make this configurable. + mask = mask & (token_ids != self.tokenizer.start_token_id) + mask = mask & (token_ids != self.tokenizer.pad_token_id) + mask = mask & (token_ids != self.tokenizer.end_token_id) + token_ids = tf.ragged.boolean_mask(token_ids, mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..121621da85 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py @@ -0,0 +1,92 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[4, 9, 5, 7, 2, 0, 0, 0]], # Labels shifted. + [[1, 1, 1, 1, 1, 0, 0, 0]], # Zero out unlabeled examples. 
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + + preprocessor = GemmaCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[9, 5, 7, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "the quick brown fox" + preprocessor = GemmaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 4, 9, 5, 7, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 4, 9, 5, 7, 2, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0], + } + preprocessor = GemmaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the quick brown fox") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaCausalLMPreprocessor.presets: + self.run_preset_test( + cls=GemmaCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/gemma_causal_lm_test.py b/keras_nlp/models/gemma/gemma_causal_lm_test.py new file mode 100644 index 0000000000..0e1d7a14f8 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_test.py @@ -0,0 +1,245 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
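As the preprocessor tests above show, `GemmaCausalLMPreprocessor.call` packs each prompt to `sequence_length + 1` tokens and then derives features, labels, and sample weights by a shift of one position. Below is a toy sketch of that slicing using the same hypothetical test-vocabulary ids as the tests (bos=1, eos=2, pad=0, "the quick brown fox" tokenizing to [4, 9, 5, 7]); it is not the layer itself.

```python
# Toy sketch of the shift-by-one labeling in GemmaCausalLMPreprocessor.call.
import numpy as np

sequence_length = 8
# Packed ids for "<bos> the quick brown fox <eos>" padded to length 9 (= 8 + 1).
token_ids = np.array([[1, 4, 9, 5, 7, 2, 0, 0, 0]])
padding_mask = np.array([[1, 1, 1, 1, 1, 1, 0, 0, 0]])

x = {
    "token_ids": token_ids[..., :-1],       # model inputs
    "padding_mask": padding_mask[..., :-1],
}
y = token_ids[..., 1:]                      # next-token labels
sample_weight = padding_mask[..., 1:]       # zero loss on padding positions

print(x["token_ids"])   # [[1 4 9 5 7 2 0 0]]
print(y)                # [[4 9 5 7 2 0 0 0]]
print(sample_weight)    # [[1 1 1 1 1 0 0 0]]
```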
+ +import os +from unittest.mock import patch + +import keras +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaCausalLMTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.preprocessor = GemmaCausalLMPreprocessor( + self.tokenizer, + sequence_length=8, + ) + self.backbone = GemmaBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=2, + num_key_value_heads=1, + hidden_dim=4, + intermediate_dim=8, + head_dim=2, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the quick brown fox"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=GemmaCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, 11), + ) + + def test_generate(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate("the quick brown fox") + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + + def test_generate_with_bfloat16(self): + original_floatx = keras.config.floatx() + keras.config.set_floatx("float16") + try: + causal_lm = GemmaCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate("the quick brown fox") + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + finally: + # Restore floatx to the original value to prevent impact on other + # tests even if there is an exception. 
+ keras.config.set_floatx(original_floatx) + + def test_early_stopping(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the quick"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=GemmaCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaCausalLM.presets: + self.run_preset_test( + cls=GemmaCausalLM, + preset=preset, + input_data=self.input_data, + ) + + def test_score_logits(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8, 11) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_loss(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + target_ids = keras.ops.roll(token_ids, shift=-1, axis=1) + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="loss", + target_ids=target_ids, + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_layer_intercept_fn_exfiltration(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_embedded_shape = (2, 8, 4) + expected_score_shape = (2, 8, 11) + + # Preprocess prompts to get tokenized representations and padding masks. 
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Setup a custom intercept function that extracts the embeddings to a + # a variable from the embeddings layer and otherwise asserts on shapes. + embedded_prompts = None + + def layer_intercept_fn_for_testing(x, i): + if i == -1: + nonlocal embedded_prompts + embedded_prompts = x + else: + nonlocal expected_embedded_shape + self.assertEqual(ops.shape(x), expected_embedded_shape) + return x + + # Get the scores. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + layer_intercept_fn=layer_intercept_fn_for_testing, + ) + + # Assert shapes for info exfiltrated into the parent context. + self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape) + self.assertEqual(ops.shape(scores), expected_score_shape) diff --git a/keras_nlp/models/gemma/gemma_decoder_block.py b/keras_nlp/models/gemma/gemma_decoder_block.py new file mode 100644 index 0000000000..0a91655fc4 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_decoder_block.py @@ -0,0 +1,189 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
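`GemmaCausalLM.generate_step`, exercised by the early-stopping test above, builds the output padding mask with a cumulative-sum trick: once an `end_token_id` is generated outside the original prompt, every later position is marked as padding while the end token itself is kept. A small NumPy sketch of that logic with made-up token ids:

```python
# Toy sketch of the early-stopping mask logic in GemmaCausalLM.generate_step.
import numpy as np

end_token_id = 2
token_ids   = np.array([[1, 4, 9, 2, 6, 6]])   # generated sequence
prompt_mask = np.array([[1, 1, 0, 0, 0, 0]])   # True where the user prompt was

# End tokens that were generated, not part of the prompt.
end_locations = (token_ids == end_token_id) & ~prompt_mask.astype(bool)
end_locations = end_locations.astype("int32")   # [[0 0 0 1 0 0]]
cumsum = np.cumsum(end_locations, axis=-1)      # [[0 0 0 1 1 1]]
overflow = cumsum - end_locations               # ones strictly after the end token
padding_mask = ~overflow.astype(bool)
print(padding_mask)                             # [[ True  True  True  True False False]]
```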
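The `"loss"` scoring mode checked in `test_score_loss` reduces per-token logits against the shifted target ids with an unreduced sparse categorical cross-entropy, giving one loss value per token. A short sketch with toy shapes mirroring the test's `(2, 8, 11)` logits; this assumes a Keras 3 installation and random data rather than real model outputs.

```python
# Sketch of per-token "loss" scoring as used in GemmaCausalLM.score.
import numpy as np
import keras

batch, num_tokens, vocab = 2, 8, 11
logits = np.random.rand(batch, num_tokens, vocab).astype("float32")
target_ids = np.random.randint(0, vocab, size=(batch, num_tokens))

per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)
per_token_loss = per_token_loss_fn(target_ids, logits)
print(per_token_loss.shape)  # (2, 8): one loss value per input token
```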
+from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.gemma.gemma_attention import CachedGemmaAttention +from keras_nlp.models.gemma.rms_normalization import RMSNormalization + + +class GemmaDecoderBlock(keras.layers.Layer): + def __init__( + self, + hidden_dim, + intermediate_dim, + head_dim, + num_query_heads, + num_key_value_heads, + layer_norm_epsilon=1e-6, + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + self.pre_attention_norm = RMSNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_attention_norm", + ) + + self.attention = CachedGemmaAttention( + head_dim=head_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + dropout=dropout, + dtype=self.dtype_policy, + name="attention", + ) + + if self.dropout > 0: + self.attention_dropout = keras.layers.Dropout(rate=dropout) + self.feedforward_dropout = keras.layers.Dropout(rate=dropout) + + self.pre_ffw_norm = RMSNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_ffw_norm", + ) + + self.gating_ffw = keras.layers.EinsumDense( + equation="btd,df->btf", + output_shape=(None, self.intermediate_dim // 2), + dtype=self.dtype_policy, + name="ffw_gating", + ) + + self.gating_ffw_2 = keras.layers.EinsumDense( + equation="btd,df->btf", + output_shape=(None, self.intermediate_dim // 2), + dtype=self.dtype_policy, + name="ffw_gating_2", + ) + + self.ffw_linear = keras.layers.EinsumDense( + equation="btf,fd->btd", + output_shape=(None, self.hidden_dim), + dtype=self.dtype_policy, + name="ffw_linear", + ) + + def build(self, input_shape): + self.pre_attention_norm.build(input_shape) + self.attention.build(input_shape) + + shape = input_shape + self.pre_ffw_norm.build(shape) + self.gating_ffw.build(shape) + self.gating_ffw_2.build(shape) + + shape = self.gating_ffw.compute_output_shape(shape) + self.ffw_linear.build(shape) + self.built = True + + def compute_output_shape(self, input_shape): + # Isometric + return input_shape + + def _compute_attention_mask( + self, x, padding_mask, cache, cache_update_index + ): + decoder_mask = merge_padding_and_attention_mask( + inputs=x, padding_mask=padding_mask, attention_mask=None + ) + batch_size = ops.shape(x)[0] + input_length = output_length = ops.shape(x)[1] + if cache is not None: + input_length = ops.shape(cache)[2] + + causal_mask = compute_causal_mask( + batch_size=batch_size, + input_length=input_length, + output_length=output_length, + cache_index=cache_update_index, + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def call( + self, + x, + padding_mask=None, + cache=None, + cache_update_index=0, + ): + normalized_x = self.pre_attention_norm(x) + attention_mask = self._compute_attention_mask( + normalized_x, padding_mask, cache, cache_update_index + ) + if cache is not None: + attention, new_cache = self.attention( + normalized_x, + attention_mask=attention_mask, + cache=cache, + cache_update_index=cache_update_index, + ) + else: + 
attention = self.attention( + normalized_x, + attention_mask=attention_mask, + ) + + if self.dropout: + attention = self.attention_dropout(attention) + + attention_x = x + attention + normalized_x = self.pre_ffw_norm(attention_x) + + x1 = self.gating_ffw(normalized_x) + x2 = self.gating_ffw_2(normalized_x) + x = keras.activations.gelu(x1, approximate=True) * x2 + x = self.ffw_linear(x) + + x = x + attention_x + + if cache is not None: + return x, new_cache + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "head_dim": self.head_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_nlp/models/gemma/gemma_lora_test.py b/keras_nlp/models/gemma/gemma_lora_test.py new file mode 100644 index 0000000000..1cbbdfa67f --- /dev/null +++ b/keras_nlp/models/gemma/gemma_lora_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import pytest + +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaLoraTest(TestCase): + def setUp(self): + self._init_kwargs = { + "vocabulary_size": 50, + "num_layers": 2, + "num_query_heads": 2, + "num_key_value_heads": 2, + "hidden_dim": 32, + "intermediate_dim": 16, + "head_dim": 16, + "layer_norm_epsilon": 1e-6, + } + + def test_lora_fine_tuning(self): + # Set up backbone and preprocessor. + backbone = GemmaBackbone(**self._init_kwargs) + backbone.enable_lora(4) + # 4 layers, 2 weights per layer + self.assertLen(backbone.trainable_weights, 4 * 2) + self.assertLen(backbone.non_trainable_weights, 20) + input_data = { + "token_ids": np.ones((2, 5), dtype="int32"), + "padding_mask": np.ones((2, 5), dtype="int32"), + } + targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"])) + + # Test fine-tuning + backbone.compile(optimizer="sgd", loss="mse") + backbone.fit(input_data, targets, epochs=1) + + # Test saving and reloading. 
+ temp_filepath = os.path.join( + self.get_temp_dir(), "lora_model.weights.h5" + ) + backbone.save_weights(temp_filepath) + new_backbone = GemmaBackbone(**self._init_kwargs) + new_backbone.load_weights(temp_filepath) + ref_out = backbone(input_data) + new_out = new_backbone(input_data) + self.assertAllClose(ref_out, new_out) + + def test_lora_saving_and_reloading(self): + backbone = GemmaBackbone(**self._init_kwargs) + initial_model_filepath = os.path.join( + self.get_temp_dir(), "base.weights.h5" + ) + backbone.save_weights(initial_model_filepath) + + backbone.enable_lora(4) + input_data = { + "token_ids": np.ones((2, 5), dtype="int32"), + "padding_mask": np.ones((2, 5), dtype="int32"), + } + targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"])) + backbone.compile(optimizer="sgd", loss="mse") + backbone.fit(input_data, targets, epochs=1) + + lora_filepath = os.path.join(self.get_temp_dir(), "lora_model.lora.h5") + backbone.save_lora_weights(lora_filepath) + + # New backbone with same initial weights + new_backbone = GemmaBackbone(**self._init_kwargs) + new_backbone.load_weights(initial_model_filepath) + new_backbone.enable_lora(4) + new_backbone.load_lora_weights(lora_filepath) + + ref_out = backbone(input_data) + new_out = new_backbone(input_data) + self.assertAllClose(ref_out, new_out) + + # Test exceptions + backbone = GemmaBackbone(**self._init_kwargs) + with self.assertRaisesRegex(ValueError, "no lora-enabled layers"): + backbone.save_lora_weights(lora_filepath) + backbone.enable_lora(5) + with self.assertRaisesRegex(ValueError, "ranks must match"): + backbone.load_lora_weights(lora_filepath) + with self.assertRaisesRegex(ValueError, "filename must end in"): + backbone.save_lora_weights("bad_filepath") diff --git a/keras_nlp/models/gemma/gemma_preprocessor.py b/keras_nlp/models/gemma/gemma_preprocessor.py new file mode 100644 index 0000000000..8fc3beb48c --- /dev/null +++ b/keras_nlp/models/gemma/gemma_preprocessor.py @@ -0,0 +1,199 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.models.preprocessor import Preprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.GemmaPreprocessor") +class GemmaPreprocessor(Preprocessor): + """Gemma preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do 2 things: + + - Tokenize the inputs using the `tokenizer`. + - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can + be passed directly to a `keras_nlp.models.GemmaBackbone`. 
+ + This layer can be used directly with `tf.data.Dataset.map` to preprocess + string data in the `(x, y, sample_weight)` format used by + `keras.Model.fit`. + + The call method of this layer accepts three arguments, `x`, `y`, and + `sample_weight`. `x` can be a python string or tensor representing a single + segment, a list of python strings representing a batch of single segments, + or a list of tensors representing multiple segments to be packed together. + `y` and `sample_weight` are both optional, can have any format, and will be + passed through unaltered. + + `GemmaPreprocessor` expects the input to have only one segment, as Gemma is + mainly used for generation tasks. For tasks having multi-segment inputs + please combine inputs into a single string input before passing to the + preprocessor layer. + + Args: + tokenizer: A `keras_nlp.models.GemmaTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Any label data. Will be passed through unaltered. + sample_weight: Any label weight data. Will be passed through unaltered. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + + Directly calling the layer on data. + ```python + preprocessor = keras_nlp.models.GemmaPreprocessor.from_preset( + "gemma_2b_en" + ) + + # Tokenize and pack a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Custom vocabulary. + bytes_io = io.BytesIO() + ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."]) + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=ds.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=8, + model_type="WORD", + pad_id=0, + bos_id=1, + eos_id=2, + unk_id=3, + pad_piece="", + bos_piece="", + eos_piece="", + unk_piece="", + ) + tokenizer = keras_nlp.models.GemmaTokenizer( + proto=bytes_io.getvalue(), + ) + preprocessor = keras_nlp.models.GemmaPreprocessor(tokenizer=tokenizer) + preprocessor("The quick brown fox jumped.") + ``` + + Apply preprocessing to a `tf.data.Dataset`. + ```python + preprocessor = keras_nlp.models.GemmaPreprocessor.from_preset( + "gemma_2b_en" + ) + + text = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + label = tf.constant([1, 1]) + + # Map labeled single sentences. + ds = tf.data.Dataset.from_tensor_slices((text, label)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map unlabeled single sentences. + ds = tf.data.Dataset.from_tensor_slices(text) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def __init__( + self, + tokenizer, + sequence_length=8192, + add_start_token=True, + add_end_token=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.tokenizer = tokenizer + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. 
+ self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "GemmaPreprocessor requires each input to contain only " + f"one segment, but received {len(x)}. If you are using Gemma " + "for a multi-segment classification task, please combine your " + "input into a single string." + ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @classproperty + def tokenizer_cls(cls): + return GemmaTokenizer diff --git a/keras_nlp/models/gemma/gemma_preprocessor_test.py b/keras_nlp/models/gemma/gemma_preprocessor_test.py new file mode 100644 index 0000000000..f54a509979 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_preprocessor_test.py @@ -0,0 +1,74 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
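`GemmaPreprocessor` delegates packing to the `StartEndPacker` configured in `build()` above. The sketch below is a simplified plain-Python model of that behavior (not the KerasNLP layer), using the hypothetical test-vocabulary ids from the tests (bos=1, eos=2, pad=0, "the quick brown fox" tokenizing to [4, 9, 5, 7]) so the outputs line up with the test expectations that follow.

```python
# Simplified model of start/end/pad packing and padding-mask construction.
def pack(ids, sequence_length, start_id=1, end_id=2, pad_id=0,
         add_start=True, add_end=True):
    ids = list(ids)
    if add_start:
        ids = [start_id] + ids
    if add_end:
        # Truncate first so an appended end token is never cut off.
        ids = ids[: sequence_length - 1] + [end_id]
    else:
        ids = ids[:sequence_length]
    padding_mask = [1] * len(ids) + [0] * (sequence_length - len(ids))
    ids = ids + [pad_id] * (sequence_length - len(ids))
    return ids, padding_mask

print(pack([4, 9, 5, 7], sequence_length=8))
# ([1, 4, 9, 5, 7, 2, 0, 0], [1, 1, 1, 1, 1, 1, 0, 0])

# Overriding sequence_length=4 truncates but keeps the end token, matching
# test_sequence_length_override below.
print(pack([4, 9, 5, 7], sequence_length=4))
# ([1, 4, 9, 2], [1, 1, 1, 1])
```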
+ +import os + +import pytest + +from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + preprocessor = GemmaPreprocessor( + tokenizer=self.tokenizer, + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "the quick brown fox" + preprocessor = GemmaPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [1, 4, 9, 2]) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaPreprocessor.presets: + self.run_preset_test( + cls=GemmaPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/gemma_presets.py b/keras_nlp/models/gemma/gemma_presets.py new file mode 100644 index 0000000000..f63fef17fa --- /dev/null +++ b/keras_nlp/models/gemma/gemma_presets.py @@ -0,0 +1,66 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gemma model preset configurations.""" + +# Metadata for loading pretrained model weights. +backbone_presets = { + "gemma_2b_en": { + "metadata": { + "description": ( + "18-layer Gemma model (Gemma with 2B parameters). " + ), + "params": 2506172416, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/1", + }, + "gemma_instruct_2b_en": { + "metadata": { + "description": ( + "18-layer Gemma model (Gemma with 2B parameters). " + ), + "params": 2506172416, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/1", + }, + "gemma_7b_en": { + "metadata": { + "description": ( + "28-layer Gemma model (Gemma with 7B parameters). 
" + ), + "params": 8537680896, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/1", + }, + "gemma_instruct_7b_en": { + "metadata": { + "description": ( + "28-layer Gemma model (Gemma with 7B parameters). " + ), + "params": 8537680896, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/1", + }, +} diff --git a/keras_nlp/models/gemma/gemma_tokenizer.py b/keras_nlp/models/gemma/gemma_tokenizer.py new file mode 100644 index 0000000000..6a4bb76ea0 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_tokenizer.py @@ -0,0 +1,108 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.GemmaTokenizer") +class GemmaTokenizer(SentencePieceTokenizer): + """Gemma tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + Gemma models and provides a `from_preset()` method to automatically + download a matching vocabulary for a Gemma preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_nlp.models.GemmaTokenizer.from_preset("gemma_2b_en") + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. 
+    bytes_io = io.BytesIO()
+    ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
+    sentencepiece.SentencePieceTrainer.train(
+        sentence_iterator=ds.as_numpy_iterator(),
+        model_writer=bytes_io,
+        vocab_size=8,
+        model_type="WORD",
+        pad_id=0,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
+        pad_piece="<pad>",
+        bos_piece="<bos>",
+        eos_piece="<eos>",
+        unk_piece="<unk>",
+    )
+    tokenizer = keras_nlp.models.GemmaTokenizer(
+        proto=bytes_io.getvalue(),
+    )
+    tokenizer("The quick brown fox jumped.")
+    ```
+    """
+
+    def __init__(self, proto, **kwargs):
+        self.start_token = "<bos>"
+        self.end_token = "<eos>"
+        self.pad_token = "<pad>"
+
+        super().__init__(proto=proto, **kwargs)
+
+    def set_proto(self, proto):
+        super().set_proto(proto)
+        if proto is not None:
+            for token in [self.end_token, self.pad_token]:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+            self.start_token_id = self.token_to_id(self.start_token)
+            self.end_token_id = self.token_to_id(self.end_token)
+            self.pad_token_id = self.token_to_id(self.pad_token)
+        else:
+            self.start_token_id = None
+            self.end_token_id = None
+            self.pad_token_id = None
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/gemma/gemma_tokenizer_test.py b/keras_nlp/models/gemma/gemma_tokenizer_test.py
new file mode 100644
index 0000000000..1c617dd937
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_tokenizer_test.py
@@ -0,0 +1,67 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
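Once a valid proto is set, `set_proto()` above resolves the special-token ids that the preprocessor's packer relies on. A hedged example, assuming the vocabulary generated by `create_gemma_test_proto.py` sits at the repository's usual `keras_nlp/tests/test_data/` location:

```python
# Sketch only; the proto path is an assumption about the test asset location.
from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer

tokenizer = GemmaTokenizer(
    proto="keras_nlp/tests/test_data/gemma_test_vocab.spm"
)
print(tokenizer.pad_token_id, tokenizer.start_token_id, tokenizer.end_token_id)
# Expected to print "0 1 2", which is why the preprocessor test above expects
# packed ids of [[1, 4, 9, 5, 7, 2, 0, 0]] for "the quick brown fox".
```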
+ +import os + +import pytest + +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaTokenizerTest(TestCase): + def setUp(self): + self.init_kwargs = { + # Generated using create_gemma_test_proto.py + "proto": os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ) + } + self.input_data = ["the quick brown fox", "the earth is round"] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[4, 9, 5, 7], [4, 6, 8, 10]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + GemmaTokenizer( + # Generated using create_no_special_token_proto.py + proto=os.path.join( + self.get_test_data_dir(), "no_special_token_vocab.spm" + ) + ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=GemmaTokenizer, + preset="gemma_2b_en", + input_data=["The quick brown fox."], + expected_output=[[651, 4320, 8426, 25341, 235265]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaTokenizer.presets: + self.run_preset_test( + cls=GemmaTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/rms_normalization.py b/keras_nlp/models/gemma/rms_normalization.py new file mode 100644 index 0000000000..ce9bdaf880 --- /dev/null +++ b/keras_nlp/models/gemma/rms_normalization.py @@ -0,0 +1,40 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +class RMSNormalization(keras.layers.Layer): + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(input_shape[-1],), + initializer="zeros", + ) + self.built = True + + def call(self, x): + # Always compute normalization in float32. 
+ x = ops.cast(x, "float32") + scale = ops.cast(self.scale, "float32") + var = ops.mean(ops.square(x), axis=-1, keepdims=True) + normed_inputs = x * ops.reciprocal(ops.sqrt(var + 1e-06)) + normed_inputs = normed_inputs * (1 + scale) + return ops.cast(normed_inputs, self.compute_dtype) diff --git a/keras_nlp/models/generative_task.py b/keras_nlp/models/generative_task.py index 9a461926e4..598217d964 100644 --- a/keras_nlp/models/generative_task.py +++ b/keras_nlp/models/generative_task.py @@ -101,12 +101,7 @@ def compiled_generate_function(inputs, end_token_id, state): for v in self._sampler.variables: new_v = scope.get_current_value(v) sampler_variables.append(new_v if new_v is not None else v) - state = ( - sampler_variables, - trainable_variables, - non_trainable_variables, - ) - return outputs, state + return outputs, sampler_variables def wrapped_generate_function( inputs, @@ -115,18 +110,20 @@ def wrapped_generate_function( # Create an explicit tuple of all variable state. state = ( self._sampler.variables, - self.trainable_variables, - self.non_trainable_variables, + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], ) inputs = tree.map_structure(ops.convert_to_tensor, inputs) - outputs, state = compiled_generate_function( + outputs, sampler_variables = compiled_generate_function( inputs, end_token_id, state, ) # Only assign the sampler variables (random seeds), as other # model variables should never be updated in generation. - for ref_v, v in zip(self._sampler.variables, state[0]): + for ref_v, v in zip(self._sampler.variables, sampler_variables): ref_v.assign(v) return outputs diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py index 89c23f71de..d93b2199b0 100644 --- a/keras_nlp/models/gpt2/gpt2_backbone.py +++ b/keras_nlp/models/gpt2/gpt2_backbone.py @@ -61,6 +61,10 @@ class GPT2Backbone(Backbone): can consume. If `None`, `max_sequence_length` uses the value from sequence length. This determines the variable shape for positional embeddings. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the models computations and weights. Note that some + computations, such as softmax and layer normalization will always + be done a float32 precision regardless of dtype. Example: ```python @@ -95,70 +99,81 @@ def __init__( intermediate_dim, dropout=0.1, max_sequence_length=1024, + dtype=None, **kwargs, ): - # Inputs - token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens, positions. - token_embedding_layer = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=_gpt_2_kernel_initializer(stddev=0.01), + dtype=dtype, name="token_embedding", ) - token_embedding = token_embedding_layer(token_ids) - - # Can't use `TokenAndPositionEmbedding` layer here because of different - # initializers. - position_embedding = PositionEmbedding( + self.position_embedding = PositionEmbedding( initializer=_gpt_2_kernel_initializer(stddev=0.02), sequence_length=max_sequence_length, + dtype=dtype, name="position_embedding", - )(token_embedding) - - # Sum and apply dropout to embeddings. 
- x = keras.layers.Add(name="embeddings_add")( - (token_embedding, position_embedding) ) - x = keras.layers.Dropout( + self.embeddings_add = keras.layers.Add( + dtype=dtype, + name="embeddings_add", + ) + self.embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="embeddings_dropout", - )(x) - - # Apply successive transformer decoder blocks. + ) + self.transformer_layers = [] for i in range(num_layers): - x = TransformerDecoder( - intermediate_dim=intermediate_dim, - num_heads=num_heads, - dropout=dropout, - layer_norm_epsilon=1e-05, - activation=gelu_approximate, - kernel_initializer=_gpt_2_kernel_initializer(stddev=0.02), - normalize_first=True, - name=f"transformer_layer_{i}", - )(x, decoder_padding_mask=padding_mask) - - sequence_output = keras.layers.LayerNormalization( - name="layer_norm", + self.transformer_layers.append( + TransformerDecoder( + intermediate_dim=intermediate_dim, + num_heads=num_heads, + dropout=dropout, + layer_norm_epsilon=1e-05, + activation=gelu_approximate, + kernel_initializer=_gpt_2_kernel_initializer(stddev=0.02), + normalize_first=True, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + ) + self.layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-05, - dtype="float32", - )(x) + dtype=dtype, + name="layer_norm", + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + # Embed inputs. + tokens = self.token_embedding(token_id_input) + positions = self.position_embedding(tokens) + x = self.embeddings_add((tokens, positions)) + x = self.embeddings_dropout(x) + # Apply transformer layers. + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + # Instantiate using the Functional constructor. super().__init__( inputs={ - "token_ids": token_ids, - "padding_mask": padding_mask, + "token_ids": token_id_input, + "padding_mask": padding_mask_input, }, outputs=sequence_output, **kwargs, ) - # All references to `self` below this line + + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -166,7 +181,6 @@ def __init__( self.intermediate_dim = intermediate_dim self.dropout = dropout self.max_sequence_length = max_sequence_length - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 44eebd0a20..b0bd529da4 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -155,23 +155,21 @@ def __init__( preprocessor=None, **kwargs, ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === inputs = backbone.input hidden_states = backbone(inputs) outputs = backbone.token_embedding(hidden_states, reverse=True) - - # Instantiate using Functional API Model constructor. 
super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - self.backbone = backbone - self.preprocessor = preprocessor - self.generate_function = None - self._sampler = None - # Default compilation + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(2e-5), @@ -216,27 +214,25 @@ def call_with_cache( the final hidden representation of the input tokens, and `cache` is the decoding cache. """ - token_embedding = self.backbone.get_layer("token_embedding")(token_ids) - position_embedding = self.backbone.get_layer("position_embedding")( - token_embedding, start_index=cache_update_index - ) - x = self.backbone.get_layer("embeddings_add")( - (token_embedding, position_embedding) + tokens = self.backbone.token_embedding(token_ids) + positions = self.backbone.position_embedding( + tokens, start_index=cache_update_index ) - x = self.backbone.get_layer("embeddings_dropout")(x) + x = self.backbone.embeddings_add((tokens, positions)) + x = self.backbone.embeddings_dropout(x) # Each decoder layer has a cache; we update them separately. caches = [] - for i in range(self.backbone.num_layers): + for i, transformer_layer in enumerate(self.backbone.transformer_layers): current_cache = cache[:, i, ...] - x, next_cache = self.backbone.get_layer(f"transformer_layer_{i}")( + x, next_cache = transformer_layer( x, self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) caches.append(next_cache) cache = ops.stack(caches, axis=1) - hidden_states = x = self.backbone.get_layer("layer_norm")(x) - logits = self.backbone.get_layer("token_embedding")(x, reverse=True) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) return logits, hidden_states, cache def _build_cache(self, token_ids): @@ -302,6 +298,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. 
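The `call_with_cache` refactor above keeps the same cache layout as before and only changes how the layers are addressed. A rough sketch of the per-layer bookkeeping, with shapes chosen purely for illustration (the real cache is allocated in `_build_cache` from the backbone config):

```python
# Illustrative sketch of the self-attention cache handling in call_with_cache().
# All sizes are made up; axis 1 indexes decoder layers and axis 2 holds the
# key and value tensors for one layer.
from keras_nlp.backend import ops

batch, num_layers, max_len, num_heads, head_dim = 1, 2, 16, 4, 8
cache = ops.zeros((batch, num_layers, 2, max_len, num_heads, head_dim))

caches = []
for i in range(num_layers):
    current_cache = cache[:, i, ...]  # (batch, 2, max_len, num_heads, head_dim)
    next_cache = current_cache  # a real transformer layer returns an update here
    caches.append(next_cache)
cache = ops.stack(caches, axis=1)  # back to the stacked per-layer layout
```

Addressing layers as `self.backbone.transformer_layers[i]` instead of `get_layer(f"transformer_layer_{i}")` keeps this loop purely attribute-based, which is what the new `# === Layers ===` sections make possible.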
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py index 400273b792..0623d983a9 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py @@ -40,7 +40,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=GPT2CausalLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index 29182f77b6..82be34776f 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -118,8 +118,8 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.add_start_token = add_start_token self.add_end_token = add_end_token @@ -175,6 +175,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def presets(cls): return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py index d7dcd261ed..35129c200d 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py @@ -38,7 +38,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=GPT2Preprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py index 40cdc0d5a9..dee8addf14 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_attention.py @@ -64,7 +64,8 @@ def __init__( self.rotary_max_wavelength = rotary_max_wavelength self.rotary_dim = int(self.attn_head_size * rotary_percentage) self.rotary_embedding_layer = RotaryEmbedding( - max_wavelength=rotary_max_wavelength + max_wavelength=rotary_max_wavelength, + dtype=self.dtype_policy, ) self.kernel_initializer = keras.initializers.get(kernel_initializer) self.bias_initializer = keras.initializers.get(bias_initializer) @@ -76,15 +77,22 @@ def build(self, input_shape): output_shape=(None, self.num_heads, 3 * self.attn_head_size), bias_axes="de", **self._get_common_kwargs_for_sublayer(use_bias=True), + dtype=self.dtype_policy, name="query_key_value", ) self._qkv_dense.build(input_shape) self._attn_dropout_layer = keras.layers.Dropout( - self.dropout, name="attention_dropout" + self.dropout, + dtype=self.dtype_policy, + name="attention_dropout", ) - self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax") + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) # Output. 
self._output_dense = keras.layers.EinsumDense( @@ -92,6 +100,7 @@ def build(self, input_shape): output_shape=(None, self.hidden_dim), bias_axes="d", **self._get_common_kwargs_for_sublayer(use_bias=True), + dtype=self.dtype_policy, name="attention_output", ) diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py index 6804331aed..0ef65a42f1 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone @@ -24,7 +23,7 @@ def _gpt_neo_x_kernel_initializer(stddev=0.02): return keras.initializers.RandomNormal(stddev=stddev) -@keras_nlp_export("keras_nlp.models.GPTNeoXBackbone") +@keras.saving.register_keras_serializable(package="keras_nlp") class GPTNeoXBackbone(Backbone): """GPT-NeoX core network with hyperparameters. @@ -61,6 +60,10 @@ class GPTNeoXBackbone(Backbone): can consume. If `None`, `max_sequence_length` uses the value from sequence length. This determines the variable shape for positional embeddings. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. """ def __init__( @@ -75,31 +78,25 @@ def __init__( rotary_max_wavelength=10000, layer_norm_epsilon=1e-5, max_sequence_length=512, + dtype=None, **kwargs, ): - # Inputs - token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens - token_embedding_layer = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=_gpt_neo_x_kernel_initializer(stddev=0.01), + dtype=dtype, name="token_embedding", ) - token_embedding = token_embedding_layer(token_ids) - - x = keras.layers.Dropout( + self.embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="embeddings_dropout", - )(token_embedding) - - # Apply successive transformer decoder blocks. + ) + self.transformer_layers = [] for i in range(num_layers): - x = GPTNeoXDecoder( + layer = GPTNeoXDecoder( intermediate_dim=intermediate_dim, num_heads=num_heads, dropout=dropout, @@ -109,26 +106,40 @@ def __init__( layer_norm_epsilon=layer_norm_epsilon, activation=gelu_approximate, kernel_initializer=_gpt_neo_x_kernel_initializer(stddev=0.02), + dtype=dtype, name=f"transformer_layer_{i}", - )(x, decoder_padding_mask=padding_mask) - - sequence_output = keras.layers.LayerNormalization( - name="layer_norm", + ) + self.transformer_layers.append(layer) + self.layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=layer_norm_epsilon, - dtype="float32", - )(x) + dtype=dtype, + name="layer_norm", + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + # Embed tokens. 
+ x = self.token_embedding(token_id_input) + x = self.embeddings_dropout(x) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) super().__init__( inputs={ - "token_ids": token_ids, - "padding_mask": padding_mask, + "token_ids": token_id_input, + "padding_mask": padding_mask_input, }, outputs=sequence_output, **kwargs, ) - # All references to `self` below this line + + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -139,7 +150,6 @@ def __init__( self.rotary_max_wavelength = rotary_max_wavelength self.max_sequence_length = max_sequence_length self.layer_norm_epsilon = layer_norm_epsilon - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py index 0f813470aa..7725a9f6d8 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.models.generative_task import GenerativeTask @@ -23,7 +22,7 @@ from keras_nlp.utils.python_utils import classproperty -@keras_nlp_export("keras_nlp.models.GPTNeoXCausalLM") +@keras.saving.register_keras_serializable(package="keras_nlp") class GPTNeoXCausalLM(GenerativeTask): """An end-to-end GPTNeoX model for causal language modeling. @@ -52,23 +51,21 @@ def __init__( preprocessor=None, **kwargs, ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === inputs = backbone.input hidden_states = backbone(inputs) outputs = backbone.token_embedding(hidden_states, reverse=True) - - # Instantiate using Functional API Model constructor. super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - self.backbone = backbone - self.preprocessor = preprocessor - self.generate_function = None - self._sampler = None - # Default compilation + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(2e-5), @@ -109,20 +106,20 @@ def call_with_cache( the final hidden representation of the input tokens, and `cache` is the decoding cache. """ - token_embedding = self.backbone.get_layer("token_embedding")(token_ids) - x = self.backbone.get_layer("embeddings_dropout")(token_embedding) + token_embedding = self.backbone.token_embedding(token_ids) + x = self.backbone.embeddings_dropout(token_embedding) # Each decoder layer has a cache; we update them separately. caches = [] - for i in range(self.backbone.num_layers): + for i, transformer_layer in enumerate(self.backbone.transformer_layers): current_cache = cache[:, i, ...] 
- x, next_cache = self.backbone.get_layer(f"transformer_layer_{i}")( + x, next_cache = transformer_layer( x, self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) caches.append(next_cache) cache = ops.stack(caches, axis=1) - x = self.backbone.get_layer("layer_norm")(x) + x = self.backbone.layer_norm(x) hidden_states = x logits = self.backbone.token_embedding(hidden_states, reverse=True) return logits, hidden_states, cache @@ -190,6 +187,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py index 92ff9bbb03..665622540e 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py @@ -15,7 +15,7 @@ import tensorflow as tf from absl import logging -from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.models.gpt_neo_x.gpt_neo_x_preprocessor import ( GPTNeoXPreprocessor, @@ -26,7 +26,7 @@ from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -@keras_nlp_export("keras_nlp.models.GPTNeoXCausalLMPreprocessor") +@keras.saving.register_keras_serializable(package="keras_nlp") class GPTNeoXCausalLMPreprocessor(GPTNeoXPreprocessor): """GPT-NeoX Causal LM preprocessor. diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py index f5a7c57421..e873c38c79 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py @@ -40,7 +40,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=GPTNeoXCausalLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py index ae646fb2b6..0a7bad7cd9 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_decoder.py @@ -99,18 +99,21 @@ def build(self, decoder_sequence_shape): max_sequence_length=self.max_sequence_length, kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, name="self_attention", ) self._self_attention_layer.build(decoder_sequence_shape) self._self_attention_layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, name="self_attention_layer_norm", ) self._self_attention_layer_norm.build(decoder_sequence_shape) self._self_attention_dropout = keras.layers.Dropout( rate=self.dropout, + dtype=self.dtype_policy, name="self_attention_dropout", ) @@ -120,6 +123,7 @@ def build(self, decoder_sequence_shape): activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, name="feedforward_intermediate_dense", ) self._feedforward_intermediate_dense.build(decoder_sequence_shape) @@ -128,6 +132,7 @@ def build(self, decoder_sequence_shape): hidden_dim, 
kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, name="feedforward_output_dense", ) @@ -137,12 +142,14 @@ def build(self, decoder_sequence_shape): self._feedforward_layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, name="feedforward_layer_norm", ) self._feedforward_layer_norm.build(decoder_sequence_shape) self._feedforward_dropout = keras.layers.Dropout( rate=self.dropout, + dtype=self.dtype_policy, name="feedforward_dropout", ) self.built = True @@ -211,9 +218,11 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index, + ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py index 1db4fe4c9b..4e675c9c98 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.models.preprocessor import Preprocessor @@ -23,7 +23,7 @@ from keras_nlp.utils.python_utils import classproperty -@keras_nlp_export("keras_nlp.models.GPTNeoXPreprocessor") +@keras.saving.register_keras_serializable(package="keras_nlp") class GPTNeoXPreprocessor(Preprocessor): """GPTNeoX preprocessing layer which tokenizes and packs inputs. 
@@ -74,12 +74,11 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.add_start_token = add_start_token self.add_end_token = add_end_token - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -132,6 +131,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return GPTNeoXTokenizer diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py index c87329af4a..92ea191596 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py @@ -38,7 +38,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=GPTNeoXPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py index d109c5849d..cc63e99af6 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -@keras_nlp_export("keras_nlp.models.GPTNeoXTokenizer") +@keras.saving.register_keras_serializable(package="keras_nlp") class GPTNeoXTokenizer(BytePairTokenizer): """A GPTNeoX tokenizer using Byte-Pair Encoding subword segmentation. 
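The `sequence_length` property added to the GPT-2 and GPT-NeoX preprocessors above exists so that changing the attribute after the packer has been built also updates the packer. A minimal, self-contained sketch of the pattern (class and names are hypothetical, not from this diff):

```python
# Hypothetical illustration of the sequence_length property/setter pattern.
from types import SimpleNamespace


class TinyPreprocessor:
    def __init__(self, sequence_length):
        self.packer = None  # created lazily, as in build() above
        self.sequence_length = sequence_length

    def build(self):
        self.packer = SimpleNamespace(sequence_length=self.sequence_length)

    @property
    def sequence_length(self):
        return self._sequence_length

    @sequence_length.setter
    def sequence_length(self, value):
        self._sequence_length = value
        if self.packer is not None:
            # Keep the packer in sync with the preprocessor.
            self.packer.sequence_length = value


preprocessor = TinyPreprocessor(sequence_length=8)
preprocessor.build()
preprocessor.sequence_length = 128
assert preprocessor.packer.sequence_length == 128
```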
diff --git a/keras_nlp/models/llama/llama_attention.py b/keras_nlp/models/llama/llama_attention.py index a2604e5351..529e73b009 100644 --- a/keras_nlp/models/llama/llama_attention.py +++ b/keras_nlp/models/llama/llama_attention.py @@ -58,6 +58,7 @@ def build(self, inputs_shape): equation="bqm,muh->bquh", output_shape=(None, self.num_query_heads, self.attn_head_size), kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, name="query", ) self._query_dense.build(inputs_shape) @@ -65,6 +66,7 @@ def build(self, inputs_shape): equation="bkm,mvh->bkvh", output_shape=(None, self.num_key_value_heads, self.attn_head_size), kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, name="key", ) self._key_dense.build(inputs_shape) @@ -73,16 +75,22 @@ def build(self, inputs_shape): equation="bkm,mvh->bkvh", output_shape=(None, self.num_key_value_heads, self.attn_head_size), kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, name="value", ) self._value_dense.build(inputs_shape) - self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax") + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) self._output_dense = keras.layers.EinsumDense( equation="bqm,mh->bqh", output_shape=(None, self.hidden_dim), kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, name="attention_output", ) self._output_dense.build(inputs_shape) @@ -90,6 +98,7 @@ def build(self, inputs_shape): self._rotary_embedding_layer = RotaryEmbedding( max_wavelength=self.rope_max_wavelength, scaling_factor=self.rope_scaling_factor, + dtype=self.dtype_policy, ) self._rotary_embedding_layer.build(inputs_shape) @@ -173,10 +182,10 @@ def _compute_attention(self, query, key, value, attention_mask=None): ) attention_scores /= norm_factor - attention_scores = self._masked_softmax( attention_scores, attention_mask ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) attention_output = ops.einsum( "acbe,aecd->abcd", attention_scores, value ) diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py index 63438544cc..6534fcc0ec 100644 --- a/keras_nlp/models/llama/llama_backbone.py +++ b/keras_nlp/models/llama/llama_backbone.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding @@ -24,7 +23,7 @@ def _llama_kernel_initializer(stddev=0.02): return keras.initializers.RandomNormal(stddev=stddev) -@keras_nlp_export("keras_nlp.models.LlamaBackbone") +@keras.saving.register_keras_serializable(package="keras_nlp") class LlamaBackbone(Backbone): """ LLaMA core network with hyperparameters. @@ -58,7 +57,10 @@ class LlamaBackbone(Backbone): can consume. If `None`, `max_sequence_length` uses the value from sequence length. This determines the variable shape for positional embeddings. - + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. 
""" def __init__( @@ -73,28 +75,21 @@ def __init__( rope_max_wavelength=10000, layer_norm_epsilon=1e-5, max_sequence_length=4096, + dtype=None, **kwargs, ): - # Inputs - token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens - token_embedding = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=_llama_kernel_initializer(stddev=0.01), tie_weights=False, + dtype=dtype, name="token_embedding", - )(token_ids) - - x = token_embedding - - # Apply successive transformer decoder blocks. + ) + self.transformer_layers = [] for i in range(num_layers): - x = LlamaDecoder( + layer = LlamaDecoder( intermediate_dim=intermediate_dim, num_query_heads=num_query_heads, num_key_value_heads=num_key_value_heads, @@ -104,24 +99,37 @@ def __init__( layer_norm_epsilon=layer_norm_epsilon, activation=ops.silu, kernel_initializer=_llama_kernel_initializer(stddev=0.02), + dtype=dtype, name=f"transformer_layer_{i}", - )(x, decoder_padding_mask=padding_mask) - - sequence_output = LlamaLayerNorm( - name="layer_norm", + ) + self.transformer_layers.append(layer) + self.layer_norm = LlamaLayerNorm( + dtype=dtype, epsilon=layer_norm_epsilon, - )(x) + name="layer_norm", + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) super().__init__( inputs={ - "token_ids": token_ids, - "padding_mask": padding_mask, + "token_ids": token_id_input, + "padding_mask": padding_mask_input, }, outputs=sequence_output, **kwargs, ) - # All references to `self` below this line + + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_query_heads = num_query_heads @@ -150,7 +158,3 @@ def get_config(self): } ) return config - - @property - def token_embedding(self): - return self.get_layer("token_embedding") diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py index 47bac478cc..3b9d6906b8 100644 --- a/keras_nlp/models/llama/llama_decoder.py +++ b/keras_nlp/models/llama/llama_decoder.py @@ -64,11 +64,13 @@ def build(self, decoder_sequence_shape): max_sequence_length=self.max_sequence_length, rope_scaling_factor=self.rope_scaling_factor, kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, ) self._self_attention_layer.build(decoder_sequence_shape) self._self_attention_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, ) self._self_attention_layernorm.build(decoder_sequence_shape) @@ -76,6 +78,7 @@ def build(self, decoder_sequence_shape): self._feedforward_intermediate_dense = keras.layers.Dense( self.intermediate_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, ) self._feedforward_intermediate_dense.build(decoder_sequence_shape) @@ -83,12 +86,14 @@ def build(self, decoder_sequence_shape): self.intermediate_dim, activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), + 
dtype=self.dtype_policy, ) self._feedforward_gate_dense.build(decoder_sequence_shape) self._feedforward_output_dense = keras.layers.Dense( self.hidden_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + dtype=self.dtype_policy, ) intermediate_shape = list(decoder_sequence_shape) @@ -97,6 +102,7 @@ def build(self, decoder_sequence_shape): self._feedforward_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, ) self._feedforward_layernorm.build(decoder_sequence_shape) @@ -172,9 +178,11 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index, + ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/llama/llama_tokenizer.py b/keras_nlp/models/llama/llama_tokenizer.py new file mode 100644 index 0000000000..7acdf8687c --- /dev/null +++ b/keras_nlp/models/llama/llama_tokenizer.py @@ -0,0 +1,81 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer + + +@keras_nlp_export("keras_nlp.models.LlamaTokenizer") +class LlamaTokenizer(SentencePieceTokenizer): + """Llama tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + Llama models and provides a `from_preset()` method to automatically + download a matching vocabulary for a Llama preset. + + This tokenizer does not provide truncation or padding of inputs. It can be + combined with a `keras_nlp.models.LlamaPreprocessor` layer for input + packing. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + ```python + # Unbatched input. + tokenizer = keras_nlp.models.LlamaTokenizer.from_preset( + "llama_7b_en", + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. 
+    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+    ```
+    """
+
+    def __init__(self, proto, **kwargs):
+        self.start_token = "<s>"
+        self.end_token = "</s>"
+        super().__init__(proto=proto, **kwargs)
+
+    def set_proto(self, proto):
+        super().set_proto(proto)
+        if proto is not None:
+            for token in [self.start_token, self.end_token]:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+            self.start_token_id = self.token_to_id(self.start_token)
+            self.end_token_id = self.token_to_id(self.end_token)
+            self.pad_token_id = 0
+        else:
+            self.start_token_id = None
+            self.end_token_id = None
+            self.pad_token_id = None
diff --git a/keras_nlp/models/llama/llama_tokenizer_test.py b/keras_nlp/models/llama/llama_tokenizer_test.py
new file mode 100644
index 0000000000..9a3c225456
--- /dev/null
+++ b/keras_nlp/models/llama/llama_tokenizer_test.py
@@ -0,0 +1,46 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+class LlamaTokenizerTest(TestCase):
+    def setUp(self):
+        self.init_kwargs = {
+            # Generated using create_llama_test_proto.py
+            "proto": os.path.join(
+                self.get_test_data_dir(), "llama_test_vocab.spm"
+            )
+        }
+        self.input_data = ["the quick brown fox", "the earth is round"]
+
+    def test_tokenizer_basics(self):
+        self.run_preprocessing_layer_test(
+            cls=LlamaTokenizer,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=[[3, 8, 4, 6], [3, 5, 7, 9]],
+        )
+
+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            LlamaTokenizer(
+                # Generated using create_no_special_token_proto.py
+                proto=os.path.join(
+                    self.get_test_data_dir(), "no_special_token_vocab.spm"
+                )
+            )
diff --git a/keras_nlp/models/mistral/mistral_attention.py b/keras_nlp/models/mistral/mistral_attention.py
index 680f1f6d1b..6511a22446
--- a/keras_nlp/models/mistral/mistral_attention.py
+++ b/keras_nlp/models/mistral/mistral_attention.py
@@ -69,7 +69,7 @@ def build(self, inputs_shape):
             equation="bqm,muh->bquh",
             output_shape=(None, self._num_query_heads, self._head_dim),
             kernel_initializer=self._kernel_initializer,
-            dtype=self.compute_dtype,
+            dtype=self.dtype_policy,
             name="query",
         )
         self._query_dense.build(inputs_shape)
@@ -82,7 +82,7 @@ def build(self, inputs_shape):
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
-            dtype=self.compute_dtype,
+            dtype=self.dtype_policy,
             name="key",
         )
         self._key_dense.build(inputs_shape)
@@ -95,22 +95,27 @@ def build(self, inputs_shape):
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
-            dtype=self.compute_dtype,
+            dtype=self.dtype_policy,
             name="value",
         )
         self._value_dense.build(inputs_shape)

-        self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")
+
self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) self._dropout_layer = keras.layers.Dropout( - rate=self._dropout, dtype=self.compute_dtype + rate=self._dropout, + dtype=self.dtype_policy, ) self._output_dense = keras.layers.EinsumDense( equation="bquh,uhm->bqm", output_shape=(None, self._hidden_dim), kernel_initializer=self._kernel_initializer, - dtype=self.compute_dtype, + dtype=self.dtype_policy, name="attention_output", ) self._output_dense.build( @@ -120,7 +125,7 @@ def build(self, inputs_shape): self.rotary_embedding_layer = RotaryEmbedding( max_wavelength=self._rope_max_wavelength, scaling_factor=self._rope_scaling_factor, - dtype=self.compute_dtype, + dtype=self.dtype_policy, ) self._dot_product_equation = "bquh,bkuh->buqk" @@ -136,7 +141,6 @@ def call( cache_update_index=None, training=None, ): - seq_len = ops.shape(hidden_states)[1] start_index = ( cache_update_index if cache_update_index is not None else 0 ) @@ -148,89 +152,34 @@ def call( query = self._query_dense(hidden_states) - # Note that the original PyTorch implementation uses - # view_as_complex/view_as_real while we use split/concatenate to - # convert to/from complex numbers. The transformations below make - # the rope computation numerically equivalent to the original - # implementation. - def _mistral_rope(x): - x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1) - x = self.rotary_embedding_layer(x, start_index=start_index) - x = ops.reshape( - ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x) - ) - return x - # Compute RoPE for queries - query = _mistral_rope(query) + query = self.rotary_embedding_layer(query, start_index=start_index) def _compute_key_value(x): key, value = self._key_dense(x), self._value_dense(x) - key = _mistral_rope(key) + # Compute RoPE for keys + key = self.rotary_embedding_layer(key, start_index=start_index) return key, value if cache is not None: - cache_k = cache[:, 0, ...] - cache_v = cache[:, 1, ...] - + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update, value_update = _compute_key_value(hidden_states) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: if cache_update_index is not None: - # Compute the new keys and values - key, value = _compute_key_value(hidden_states) - - # Cache is a rotating buffer, we want to warp around if - # the sequence length exceeds the sliding window. - update_end_index = ( - cache_update_index + seq_len - 1 - ) % self._sliding_window + 1 - update_end_index = ops.cast(update_end_index, "int32") - cache_update_index = cache_update_index % self._sliding_window - update_start_index = ops.cond( - update_end_index > cache_update_index, - lambda: ops.cast(cache_update_index, "int32"), - lambda: ops.cast(0, "int32"), - ) - # Also note that the update step below assumes that the - # sequence length is always one when `cache_update_index != 0`. - # This is necessary to support XLA compilation. Ideally, we - # would want to use - # `key[:, -(update_end_index - update_start_index):, ...]` - # as the update but updating using a dynamic slice gives an - # XLA compilation error in TensorFlow. 
- # Passing a sequence of length > 1 with cache update might give - # incorrect results (since there is no way to determine how - # many most recent tokens are to be saved if the tokens exceed - # the sliding window length). - cache_k = ops.slice_update( - cache_k, - [0, update_start_index, 0, 0], - # We slice the keys and values since if the user has passed - # a sequence of length > `self._sliding_window`. We want to - # prefill the cache using just the most recent values in the - # sliding window. - ops.cast( - key[:, -self._sliding_window :, ...], cache_k.dtype - ), + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" ) - cache_v = ops.slice_update( - cache_v, - [0, update_start_index, 0, 0], - ops.cast( - value[:, -self._sliding_window :, ...], cache_v.dtype - ), - ) - cache = ops.stack([cache_k, cache_v], axis=1) - - # Get the required keys and values from the cache. - # Since we expect the user to pass a fixed-size cache, we just - # pick the first few slices up-to and including the newly computed - # keys and values. - cache_k = cache_k[:, :update_end_index, ...] - cache_v = cache_v[:, :update_end_index, ...] - - key = ops.cast(cache_k, dtype=self.compute_dtype) - value = ops.cast(cache_v, dtype=self.compute_dtype) - else: - # Compute keys and values key, value = _compute_key_value(hidden_states) # [batch_shape, seq_len, num_key_value_heads, head_dim] @@ -260,15 +209,15 @@ def _masked_softmax(self, attention_scores, attention_mask=None): return self._softmax(attention_scores) def _compute_attention(self, query, key, value, attention_mask=None): - attention_scores = ops.einsum(self._dot_product_equation, key, query) + attention_scores = ops.einsum(self._dot_product_equation, query, key) norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype)) attention_scores = attention_scores / norm_factor - attention_scores = self._masked_softmax( attention_scores, attention_mask ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) attention_output = ops.einsum( self._combine_equation, attention_scores, value ) diff --git a/keras_nlp/models/mistral/mistral_backbone.py b/keras_nlp/models/mistral/mistral_backbone.py index 42cec8b218..3e2cfae148 100644 --- a/keras_nlp/models/mistral/mistral_backbone.py +++ b/keras_nlp/models/mistral/mistral_backbone.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy + from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops @@ -19,9 +21,11 @@ from keras_nlp.models.mistral.mistral_layer_norm import ( MistralLayerNormalization, ) +from keras_nlp.models.mistral.mistral_presets import backbone_presets from keras_nlp.models.mistral.mistral_transformer_decoder import ( MistralTransformerDecoder, ) +from keras_nlp.utils.python_utils import classproperty def _mistral_kernel_initializer(stddev=0.02): @@ -64,7 +68,10 @@ class MistralBackbone(Backbone): layers in each transformer decoder. Only `sliding_window` number of tokens are saved in the cache and used to generate the next token. Defaults to `512`. - dtype (str, optional): The dtype policy for the mistral model. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. 
Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Examples: @@ -107,19 +114,11 @@ def __init__( layer_norm_epsilon=1e-6, sliding_window=512, dropout=0, + dtype=None, **kwargs, ): - # Get the dtype - dtype = kwargs.pop("dtype", keras.backend.floatx()) - - # Inputs - token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed Tokens - token_embedding_layer = ReversibleEmbedding( + # === Layers === + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, tie_weights=False, @@ -127,11 +126,9 @@ def __init__( dtype=dtype, name="token_embedding", ) - x = token_embedding_layer(token_ids) - - # Apply successive transformer decoder blocks + self.transformer_layers = [] for i in range(num_layers): - x = MistralTransformerDecoder( + layer = MistralTransformerDecoder( intermediate_dim=intermediate_dim, num_query_heads=num_query_heads, num_key_value_heads=num_key_value_heads, @@ -144,25 +141,35 @@ def __init__( dropout=dropout, dtype=dtype, name=f"transformer_layer_{i}", - )(x, decoder_padding_mask=padding_mask) - - sequence_output = MistralLayerNormalization( - name="sequence_output_layernorm", + ) + self.transformer_layers.append(layer) + self.layer_norm = MistralLayerNormalization( epsilon=layer_norm_epsilon, dtype=dtype, - )(x) + name="sequence_output_layernorm", + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) super().__init__( inputs={ - "token_ids": token_ids, - "padding_mask": padding_mask, + "token_ids": token_id_input, + "padding_mask": padding_mask_input, }, outputs=sequence_output, **kwargs, ) - # All references to `self` below this line + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_query_heads = num_query_heads @@ -174,7 +181,6 @@ def __init__( self.sliding_window = sliding_window self.layer_norm_epsilon = layer_norm_epsilon self.dropout = dropout - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() @@ -194,3 +200,7 @@ def get_config(self): } ) return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/mistral/mistral_backbone_test.py b/keras_nlp/models/mistral/mistral_backbone_test.py index fc2b0a592b..fbfcb91124 100644 --- a/keras_nlp/models/mistral/mistral_backbone_test.py +++ b/keras_nlp/models/mistral/mistral_backbone_test.py @@ -54,3 +54,29 @@ def test_num_parameters(self): model = MistralBackbone(**self.init_kwargs) # Reference value calculated using the PyTorch model self.assertEqual(model.count_params(), 2704) + + @pytest.mark.extra_large + def test_smallest_preset(self): + self.run_preset_test( + cls=MistralBackbone, + preset="mistral_7b_en", + input_data={ + "token_ids": ops.array([[1, 1824, 349, 524, 11234, 28804]]), + "padding_mask": ops.ones((1, 6), dtype="int32"), + }, + expected_output_shape=(1, 6, 4096), + # The forward pass from a preset should be stable! 
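The backbone rewrite above follows the pattern used throughout this change: sub-layers are created first (the "Layers" section) and only then wired into a functional graph (the "Functional Model" section), so tasks can later reach individual layers, e.g. `backbone.transformer_layers[i]`, during cached decoding. A stripped-down sketch of the create-then-wire idea with generic Keras layers standing in for the real Mistral blocks (the layer choices here are placeholders, not the actual architecture):

```python
import keras

# === Layers === (created first and kept addressable)
token_embedding = keras.layers.Embedding(input_dim=100, output_dim=8, name="token_embedding")
transformer_layers = [
    keras.layers.Dense(8, activation="relu", name=f"transformer_layer_{i}") for i in range(2)
]
layer_norm = keras.layers.LayerNormalization(name="sequence_output_layernorm")

# === Functional Model === (the same layer objects are wired into a graph)
token_id_input = keras.Input(shape=(None,), dtype="int32", name="token_ids")
x = token_embedding(token_id_input)
for layer in transformer_layers:
    x = layer(x)
outputs = layer_norm(x)
backbone = keras.Model(token_id_input, outputs)

# The wired graph and the stored references point at the same layer objects.
print(transformer_layers[0] is backbone.get_layer("transformer_layer_0"))  # True
```

In the actual diff these references are stored as attributes on the backbone instance, which is what lets `call_with_cache` iterate over `self.backbone.transformer_layers` without `get_layer` lookups.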
+ # Reference values computed using PyTorch HF model. + expected_partial_output=ops.array( + [-1.6875, 0.5117, -1.7188, 2.3125, -0.0996] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MistralBackbone.presets: + self.run_preset_test( + cls=MistralBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/mistral/mistral_causal_lm.py b/keras_nlp/models/mistral/mistral_causal_lm.py new file mode 100644 index 0000000000..3296bb9495 --- /dev/null +++ b/keras_nlp/models/mistral/mistral_causal_lm.py @@ -0,0 +1,219 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.mistral.mistral_backbone import MistralBackbone +from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import ( + MistralCausalLMPreprocessor, +) +from keras_nlp.models.mistral.mistral_presets import backbone_presets +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.MistralCausalLM") +class MistralCausalLM(GenerativeTask): + """An end-to-end Mistral model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a GPT-NeoX model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"top_k"` sampling will be used. + + Args: + backbone: A `keras_nlp.models.MistralBackbone` instance. + preprocessor: A `keras_nlp.models.MistralCausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. 
+ """ + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.inputs + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Default compilation === + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(2e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], + jit_compile=True, + ) + + @classproperty + def backbone_cls(cls): + return MistralBackbone + + @classproperty + def preprocessor_cls(cls): + return MistralCausalLMPreprocessor + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `MistralCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] + x, next_cache = self.backbone.transformer_layers[i]( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + end_token_id=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + end_token_id: The id of the end token to stop on. If all + sequences have produced a new `end_token_id`, generation + will stop. 
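`_build_cache()` above allocates one key buffer and one value buffer per decoder layer inside a single tensor. For a concrete sense of the shape, here is the arithmetic with the toy configuration used by the tests in this change (the real preset uses the 7B hyperparameters instead):

```python
# Shape of the empty decoding cache built by `_build_cache()`, using the toy
# test configuration from this change (batch of 2 prompts, sequence length 8).
batch_size, max_length = 2, 8
num_layers = 2
num_query_heads = 4
num_key_value_heads = 2
hidden_dim = 8

head_dim = hidden_dim // num_query_heads  # 2
cache_shape = (
    batch_size,
    num_layers,
    2,  # one slot for keys, one for values
    max_length,
    num_key_value_heads,
    head_dim,
)
print(cache_shape)  # (2, 2, 2, 8, 2, 2)
```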
+ """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self._sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + end_token_id=end_token_id, + hidden_states=hidden_states, + ) + + # Compute an output padding mask with the token ids we updated. + if end_token_id is not None: + # Build a mask of `end_token_id` locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = ops.logical_and( + ops.equal(token_ids, end_token_id), + ops.logical_not(padding_mask), + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. + padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py new file mode 100644 index 0000000000..893036cd58 --- /dev/null +++ b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py @@ -0,0 +1,185 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from absl import logging + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import ops +from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.MistralCausalLMPreprocessor") +class MistralCausalLMPreprocessor(MistralPreprocessor): + """Mistral Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_nlp.models.MistralCausalLM`. 
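The end-of-sequence handling in `generate_step()` above uses a cumulative sum to zero out every position after the first *generated* end token while keeping the prompt intact. A numpy sketch of that trick with made-up token ids (2 stands in for `end_token_id`):

```python
import numpy as np

# The first three positions are the original prompt; the rest were generated.
token_ids    = np.array([[5, 7, 9, 4, 2, 6, 2, 0]])
prompt_mask  = np.array([[1, 1, 1, 0, 0, 0, 0, 0]], dtype=bool)
end_token_id = 2

# `end_token_id` hits that were generated, i.e. not part of the prompt.
end_locations = np.logical_and(token_ids == end_token_id, ~prompt_mask).astype("int32")
# `overflow` is one at every position strictly after the first generated end token.
cumsum = np.cumsum(end_locations, axis=-1)
overflow = cumsum - end_locations
# Valid outputs are everything up to and including that end token.
padding_mask = ~overflow.astype(bool)
print(padding_mask.astype(int))  # [[1 1 1 1 1 0 0 0]]
```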
By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_nlp.models.MistralCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_nlp.models.MistralTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Default is `False`. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.MistralCausalLMPreprocessor.from_preset( + "mistral_base_en" + ) + + # Tokenize and pack a single sentence. + sentence = tf.constant("League of legends") + preprocessor(sentence) + # Same output. + preprocessor("League of legends") + + # Tokenize a batch of sentences. + sentences = tf.constant(["Taco tuesday", "Fish taco please!"]) + preprocessor(sentences) + # Same output. + preprocessor(["Taco tuesday", "Fish taco please!"]) + + # Map a dataset to preprocess a single sentence. + features = tf.constant( + [ + "Avatar 2 is amazing!", + "Well, I am not sure.", + ] + ) + labels = tf.constant([1, 0]) + ds = tf.data.Dataset.from_tensor_slices((features, labels)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map a dataset to preprocess unlabled sentences. + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`MistralCausalLMPreprocessor` generates `y` and " + "`sample_weight` based on your input data, but your data " + "already contains `y` or `sample_weight`. Your `y` and " + "`sample_weight` will be ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Covert strings to integer token input for generation. 
+ + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Covert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + # Convert the inputs to numpy arrays if they aren't a tensor already. + if not isinstance(token_ids, tf.Tensor): + token_ids = ops.convert_to_numpy(token_ids) + # Make sure the numpy array has type `int32` since + # `SentencePieceProcessor.detokenize` only accepts `int32` arrays. + token_ids = token_ids.astype("int32") + if not isinstance(padding_mask, tf.Tensor): + padding_mask = ops.convert_to_numpy(padding_mask) + padding_mask = padding_mask.astype("bool") + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. + padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id) + padding_mask = padding_mask & ( + token_ids != self.tokenizer.start_token_id + ) + token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..dbacce37ce --- /dev/null +++ b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py @@ -0,0 +1,92 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
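The preprocessor's `call()` above packs one extra token and then shifts by one position: the labels are the inputs offset to the right, and the sample weights are the shifted padding mask. With the toy vocabulary used by the tests below, the indexing works out like this (a numpy sketch of the slicing only, not of the tokenizer itself):

```python
import numpy as np

# Packed ids for "the quick brown fox": [start] + tokens, padded to
# sequence_length + 1 = 9, matching the test expectations below.
token_ids    = np.array([[1, 3, 8, 4, 6, 0, 0, 0, 0]])
padding_mask = np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0]])

x_token_ids   = token_ids[..., :-1]     # model input
y             = token_ids[..., 1:]      # next-token labels
sample_weight = padding_mask[..., 1:]   # ignore labels at padded positions

print(x_token_ids)    # [[1 3 8 4 6 0 0 0]]
print(y)              # [[3 8 4 6 0 0 0 0]]
print(sample_weight)  # [[1 1 1 1 0 0 0 0]]
```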
+ +import os + +import pytest + +from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import ( + MistralCausalLMPreprocessor, +) +from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_nlp.tests.test_case import TestCase + + +class MistralCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = MistralTokenizer( + # Generated using create_mistral_test_proto.py + proto=os.path.join( + self.get_test_data_dir(), "mistral_test_vocab.spm" + ) + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = (["the quick brown fox"],) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=MistralCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + }, + [[3, 8, 4, 6, 0, 0, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 0, 0, 0, 0]], # Pass through sample_weights. + ), + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + + preprocessor = MistralCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "the quick brown fox" + preprocessor = MistralCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 3, 8, 4, 6, 0, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0], + } + preprocessor = MistralCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the quick brown fox") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MistralCausalLMPreprocessor.presets: + self.run_preset_test( + cls=MistralCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/mistral/mistral_causal_lm_test.py b/keras_nlp/models/mistral/mistral_causal_lm_test.py new file mode 100644 index 0000000000..3f9d7fab36 --- /dev/null +++ b/keras_nlp/models/mistral/mistral_causal_lm_test.py @@ -0,0 +1,130 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from unittest.mock import patch + +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.mistral.mistral_backbone import MistralBackbone +from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM +from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import ( + MistralCausalLMPreprocessor, +) +from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_nlp.tests.test_case import TestCase + + +class MistralCausalLMTest(TestCase): + def setUp(self): + self.preprocessor = MistralCausalLMPreprocessor( + MistralTokenizer( + # Generated using create_mistral_test_proto.py + proto=os.path.join( + self.get_test_data_dir(), "mistral_test_vocab.spm" + ) + ), + sequence_length=8, + ) + self.backbone = MistralBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the earth is round"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=MistralCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, 10), + ) + + def test_generate(self): + causal_lm = MistralCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_early_stopping(self): + causal_lm = MistralCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the earth"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = MistralCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
+ causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MistralCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MistralCausalLM.presets: + self.run_preset_test( + cls=MistralCausalLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/mistral/mistral_layer_norm.py b/keras_nlp/models/mistral/mistral_layer_norm.py index 9f9ddf26b5..e714a8540d 100644 --- a/keras_nlp/models/mistral/mistral_layer_norm.py +++ b/keras_nlp/models/mistral/mistral_layer_norm.py @@ -32,7 +32,6 @@ def build(self, input_shape): trainable=True, shape=(self._dim,), initializer="ones", - dtype=self.compute_dtype, ) self.built = True diff --git a/keras_nlp/models/mistral/mistral_preprocessor.py b/keras_nlp/models/mistral/mistral_preprocessor.py index d5d838303e..38dc6da5b6 100644 --- a/keras_nlp/models/mistral/mistral_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_preprocessor.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.models.mistral.mistral_presets import backbone_presets from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( @@ -121,15 +123,21 @@ def __init__( ): super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.add_start_token = add_start_token self.add_end_token = add_end_token self.sequence_length = sequence_length + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. 
self.packer = StartEndPacker( start_value=self.tokenizer.start_token_id, end_value=self.tokenizer.end_token_id, - sequence_length=sequence_length, + sequence_length=self.sequence_length, return_padding_mask=True, ) + self.built = True def get_config(self): config = super().get_config() @@ -170,6 +178,21 @@ def call( } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return MistralTokenizer + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/mistral/mistral_preprocessor_test.py b/keras_nlp/models/mistral/mistral_preprocessor_test.py index 40528fd4e8..e3ddd38f6f 100644 --- a/keras_nlp/models/mistral/mistral_preprocessor_test.py +++ b/keras_nlp/models/mistral/mistral_preprocessor_test.py @@ -14,6 +14,8 @@ import os +import pytest + from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer from keras_nlp.tests.test_case import TestCase @@ -38,7 +40,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=MistralPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, @@ -57,3 +59,12 @@ def test_errors_for_2d_list_input(self): ambiguous_input = [["one", "two"], ["three", "four"]] with self.assertRaises(ValueError): preprocessor(ambiguous_input) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MistralPreprocessor.presets: + self.run_preset_test( + cls=MistralPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/mistral/mistral_presets.py b/keras_nlp/models/mistral/mistral_presets.py new file mode 100644 index 0000000000..82a2ec44f6 --- /dev/null +++ b/keras_nlp/models/mistral/mistral_presets.py @@ -0,0 +1,38 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mistral model preset configurations.""" + +# Metadata for loading pretrained model weights. 
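The preset registry defined just below is what the new `presets` classproperties return, which in turn enables construction by name. A usage sketch; it assumes Kaggle access to the weights referenced by the `kaggle_handle` entries and enough memory for the 7B parameters, neither of which is part of this diff:

```python
import keras_nlp

# Loads tokenizer, preprocessor and weights registered under "mistral_7b_en".
causal_lm = keras_nlp.models.MistralCausalLM.from_preset("mistral_7b_en")
causal_lm.generate("What is Keras?", max_length=64)
```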
+backbone_presets = { + "mistral_7b_en": { + "metadata": { + "description": "Mistral 7B base model", + "params": 7241732096, + "official_name": "Mistral", + "path": "mistral", + "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md", + }, + "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/3", + }, + "mistral_instruct_7b_en": { + "metadata": { + "description": "Mistral 7B instruct model", + "params": 7241732096, + "official_name": "Mistral", + "path": "mistral", + "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md", + }, + "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/3", + }, +} diff --git a/keras_nlp/models/mistral/mistral_tokenizer.py b/keras_nlp/models/mistral/mistral_tokenizer.py index 12636f69f1..59a00d302f 100644 --- a/keras_nlp/models/mistral/mistral_tokenizer.py +++ b/keras_nlp/models/mistral/mistral_tokenizer.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy + from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.mistral.mistral_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer +from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.MistralTokenizer") @@ -77,3 +81,7 @@ def set_proto(self, proto): else: self.start_token_id = None self.end_token_id = None + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/mistral/mistral_tokenizer_test.py b/keras_nlp/models/mistral/mistral_tokenizer_test.py index ea9e04f67d..6b700bf711 100644 --- a/keras_nlp/models/mistral/mistral_tokenizer_test.py +++ b/keras_nlp/models/mistral/mistral_tokenizer_test.py @@ -14,6 +14,8 @@ import os +import pytest + from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer from keras_nlp.tests.test_case import TestCase @@ -44,3 +46,21 @@ def test_errors_missing_special_tokens(self): self.get_test_data_dir(), "no_special_token_vocab.spm" ) ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=MistralTokenizer, + preset="mistral_7b_en", + input_data=["The quick brown fox."], + expected_output=[[415, 2936, 9060, 285, 1142, 28723]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MistralTokenizer.presets: + self.run_preset_test( + cls=MistralTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/mistral/mistral_transformer_decoder.py b/keras_nlp/models/mistral/mistral_transformer_decoder.py index 9b6f7fdbf8..7c90ab91b9 100644 --- a/keras_nlp/models/mistral/mistral_transformer_decoder.py +++ b/keras_nlp/models/mistral/mistral_transformer_decoder.py @@ -36,7 +36,7 @@ def __init__( num_key_value_heads, rope_max_wavelength=10000, rope_scaling_factor=1.0, - activation="relu", + activation="silu", layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", sliding_window=512, @@ -73,20 +73,20 @@ def build(self, decoder_sequence_shape): sliding_window=self.sliding_window, kernel_initializer=clone_initializer(self.kernel_initializer), dropout=self.dropout, - dtype=self.compute_dtype, + dtype=self.dtype_policy, name="self_attention", ) self._self_attention_layer.build(decoder_sequence_shape) self._self_attention_layernorm = MistralLayerNormalization( 
epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, name="self_attention_layernorm", - dtype=self.compute_dtype, ) self._self_attention_layernorm.build(decoder_sequence_shape) self._self_attention_dropout = keras.layers.Dropout( rate=self.dropout, - dtype=self.compute_dtype, + dtype=self.dtype_policy, name="self_attention_dropout", ) @@ -95,7 +95,7 @@ def build(self, decoder_sequence_shape): self.intermediate_dim, kernel_initializer=clone_initializer(self.kernel_initializer), use_bias=False, - dtype=self.compute_dtype, + dtype=self.dtype_policy, name="feedforward_intermediate_dense", ) self._feedforward_intermediate_dense.build(decoder_sequence_shape) @@ -105,6 +105,7 @@ def build(self, decoder_sequence_shape): activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), use_bias=False, + dtype=self.dtype_policy, name="feedforward_gate_dense", ) self._feedforward_gate_dense.build(decoder_sequence_shape) @@ -113,7 +114,7 @@ def build(self, decoder_sequence_shape): self.hidden_dim, kernel_initializer=clone_initializer(self.kernel_initializer), use_bias=False, - dtype=self.compute_dtype, + dtype=self.dtype_policy, name="feedforward_output_dense", ) @@ -125,8 +126,8 @@ def build(self, decoder_sequence_shape): self._feedforward_layernorm = MistralLayerNormalization( epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, name="feedforward_layernorm", - dtype=self.compute_dtype, ) self._feedforward_layernorm.build(decoder_sequence_shape) @@ -145,6 +146,8 @@ def call( decoder_sequence=decoder_sequence, decoder_padding_mask=decoder_padding_mask, decoder_attention_mask=decoder_attention_mask, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=self_attention_cache_update_index, ) residual = decoder_sequence @@ -184,23 +187,36 @@ def _compute_self_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask, + self_attention_cache, + self_attention_cache_update_index, ): decoder_mask = merge_padding_and_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask ) batch_size = ops.shape(decoder_sequence)[0] input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if self_attention_cache is not None: + input_length = ops.shape(self_attention_cache)[2] + + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) # Mistral uses a banded attention mask causal_mask_lower = compute_causal_mask( - batch_size, input_length, output_length, 0 + batch_size, input_length, output_length, cache_update_index ) # Below is a workaround for `ops.triu` for Keras 2. # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed. 
# causal_mask = ops.triu(causal_mask_lower, k=-self.sliding_window) - i = ops.arange(output_length)[:, None] + i = ops.arange(output_length)[:, None] + cache_update_index j = ops.arange(input_length)[None, :] - causal_mask_upper = ops.cast(i <= j + self.sliding_window, "int32") + causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32") causal_mask = ops.minimum(causal_mask_lower, causal_mask_upper) return ( diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py index ff1495ba9f..0b98a6c64e 100644 --- a/keras_nlp/models/opt/opt_backbone.py +++ b/keras_nlp/models/opt/opt_backbone.py @@ -57,6 +57,10 @@ class OPTBackbone(Backbone): can consume. If `None`, `max_sequence_length` uses the value from sequence length. This determines the variable shape for positional embeddings. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Examples: ```python @@ -91,27 +95,22 @@ def __init__( intermediate_dim, dropout=0.1, max_sequence_length=2048, + dtype=None, **kwargs, ): - # Decoder inputs. - token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens and positions. - embedding_layer = TokenAndPositionEmbedding( + # === Layers === + self.embeddings = TokenAndPositionEmbedding( vocabulary_size=vocabulary_size, sequence_length=max_sequence_length, embedding_dim=hidden_dim, embeddings_initializer=opt_kernel_initializer(), + dtype=dtype, name="embeddings", ) - x = embedding_layer(token_ids) - - # Apply successive transformer decoder blocks. + self.token_embedding = self.embeddings.token_embedding + self.transformer_layers = [] for i in range(num_layers): - x = TransformerDecoder( + layer = TransformerDecoder( intermediate_dim=intermediate_dim, num_heads=num_heads, dropout=dropout, @@ -119,28 +118,38 @@ def __init__( layer_norm_epsilon=1e-5, normalize_first=True, kernel_initializer=opt_kernel_initializer(), + dtype=dtype, name=f"transformer_layer_{i}", - )(x, decoder_padding_mask=padding_mask) - - # Add a final layer norm. 
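The Mistral decoder mask computed above combines the standard causal mask with an upper band so that each position attends to at most `sliding_window` tokens, including itself. A numpy sketch of the same index arithmetic with a made-up length and window size (no cache, so `cache_update_index` is zero):

```python
import numpy as np

# Row i may attend to column j only if j <= i (causal) and i < j + window (band).
seq_len, sliding_window = 6, 3
cache_update_index = 0

i = np.arange(seq_len)[:, None] + cache_update_index
j = np.arange(seq_len)[None, :]
causal_lower = (i >= j).astype("int32")                  # standard causal mask
causal_upper = (i < j + sliding_window).astype("int32")  # bounds the look-back
banded_mask = np.minimum(causal_lower, causal_upper)
print(banded_mask)
# [[1 0 0 0 0 0]
#  [1 1 0 0 0 0]
#  [1 1 1 0 0 0]
#  [0 1 1 1 0 0]
#  [0 0 1 1 1 0]
#  [0 0 0 1 1 1]]
```

The switch from `i <= j + self.sliding_window` to `i < j + self.sliding_window` is what makes the band exactly `sliding_window` wide.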
- x = keras.layers.LayerNormalization( - name="layer_norm", + ) + self.transformer_layers.append(layer) + self.layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-5, - dtype="float32", - )(x) + dtype=dtype, + name="layer_norm", + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.embeddings(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + x = self.layer_norm(x) super().__init__( inputs={ - "token_ids": token_ids, - "padding_mask": padding_mask, + "token_ids": token_id_input, + "padding_mask": padding_mask_input, }, outputs=x, **kwargs, ) - # All references to `self` below this line + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -148,7 +157,6 @@ def __init__( self.intermediate_dim = intermediate_dim self.dropout = dropout self.max_sequence_length = max_sequence_length - self.token_embedding = embedding_layer.token_embedding def get_config(self): return { diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 6197a87ffd..2ca8ee07b4 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -155,23 +155,21 @@ def __init__( preprocessor=None, **kwargs, ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === inputs = backbone.input hidden_states = backbone(inputs) outputs = backbone.token_embedding(hidden_states, reverse=True) - - # Instantiate using Functional API Model constructor. super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - self.backbone = backbone - self.preprocessor = preprocessor - self.generate_function = None - self._sampler = None - # Default compilation + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(2e-5), @@ -216,21 +214,19 @@ def call_with_cache( the final hidden representation of the input tokens, and `cache` is the decoding cache. """ - x = self.backbone.get_layer("embeddings")( - token_ids, start_index=cache_update_index - ) + x = self.backbone.embeddings(token_ids, start_index=cache_update_index) # Each decoder layer has a cache; we update them separately. caches = [] - for i in range(self.backbone.num_layers): + for i, transformer_layer in enumerate(self.backbone.transformer_layers): current_cache = cache[:, i, ...] - x, next_cache = self.backbone.get_layer(f"transformer_layer_{i}")( + x, next_cache = transformer_layer( x, self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) caches.append(next_cache) cache = ops.stack(caches, axis=1) - x = self.backbone.get_layer("layer_norm")(x) + x = self.backbone.layer_norm(x) hidden_states = x logits = self.backbone.token_embedding(hidden_states, reverse=True) return logits, hidden_states, cache @@ -298,6 +294,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. 
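A recurring change across these files is forwarding `dtype` (a string or `keras.mixed_precision.DTypePolicy`) from the backbone down to every sub-layer, and passing `self.dtype_policy` rather than `self.compute_dtype`, so sub-layers keep both the compute and variable dtypes of a mixed-precision policy instead of collapsing everything to the compute dtype. A usage sketch with the toy Mistral configuration from the tests (not the 7B preset), assuming Keras 3 mixed precision:

```python
import keras_nlp

# With a "mixed_float16" policy, computations run in float16 while variables
# (and numerically sensitive ops such as the attention softmax, which is
# pinned to float32 above) keep full precision.
backbone = keras_nlp.models.MistralBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=8,
    intermediate_dim=16,
    dtype="mixed_float16",
)
print(backbone.transformer_layers[0].compute_dtype)   # float16
print(backbone.transformer_layers[0].variable_dtype)  # float32
```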
diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py index 9ba6851d4b..e04436f092 100644 --- a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py @@ -39,7 +39,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=OPTCausalLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/opt/opt_preprocessor.py b/keras_nlp/models/opt/opt_preprocessor.py index cdca904870..8f52bb67e6 100644 --- a/keras_nlp/models/opt/opt_preprocessor.py +++ b/keras_nlp/models/opt/opt_preprocessor.py @@ -120,10 +120,10 @@ def __init__( super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.sequence_length = sequence_length self.add_start_token = add_start_token self.add_end_token = add_end_token - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -176,6 +176,17 @@ def call( } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def presets(cls): return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/opt/opt_preprocessor_test.py b/keras_nlp/models/opt/opt_preprocessor_test.py index b80c409b92..901efc7bee 100644 --- a/keras_nlp/models/opt/opt_preprocessor_test.py +++ b/keras_nlp/models/opt/opt_preprocessor_test.py @@ -37,7 +37,7 @@ def setUp(self): self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=OPTPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py index 8495b5cb69..1ab61eeeb7 100644 --- a/keras_nlp/models/roberta/roberta_backbone.py +++ b/keras_nlp/models/roberta/roberta_backbone.py @@ -61,6 +61,10 @@ class RobertaBackbone(Backbone): consume. The sequence length of the input must be less than `max_sequence_length` default value. This determines the variable shape for positional embeddings. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Examples: ```python @@ -96,60 +100,66 @@ def __init__( intermediate_dim, dropout=0.1, max_sequence_length=512, + dtype=None, **kwargs, ): - # Inputs - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" - ) - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - - # Embed tokens and positions. - embedding_layer = TokenAndPositionEmbedding( + # === Layers === + self.embeddings = TokenAndPositionEmbedding( vocabulary_size=vocabulary_size, sequence_length=max_sequence_length, embedding_dim=hidden_dim, embeddings_initializer=roberta_kernel_initializer(), + dtype=dtype, name="embeddings", ) - embedding = embedding_layer(token_id_input) - - # Sum, normalize and apply dropout to embeddings. 
- x = keras.layers.LayerNormalization( - name="embeddings_layer_norm", + self.token_embedding = self.embeddings.token_embedding + self.embeddings_layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-5, # Original paper uses this epsilon value - dtype="float32", - )(embedding) - x = keras.layers.Dropout( + dtype=dtype, + name="embeddings_layer_norm", + ) + self.embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="embeddings_dropout", - )(x) - - # Apply successive transformer encoder blocks. + ) + self.transformer_layers = [] for i in range(num_layers): - x = TransformerEncoder( + layer = TransformerEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, activation="gelu", dropout=dropout, layer_norm_epsilon=1e-5, kernel_initializer=roberta_kernel_initializer(), + dtype=dtype, name=f"transformer_layer_{i}", - )(x, padding_mask=padding_mask) + ) + self.transformer_layers.append(layer) - # Instantiate using Functional API Model constructor + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.embeddings(token_id_input) + x = self.embeddings_layer_norm(x) + x = self.embeddings_dropout(x) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask_input) super().__init__( inputs={ "token_ids": token_id_input, - "padding_mask": padding_mask, + "padding_mask": padding_mask_input, }, outputs=x, **kwargs, ) - # All references to `self` below this line + + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -158,7 +168,6 @@ def __init__( self.dropout = dropout self.max_sequence_length = max_sequence_length self.start_token_index = 0 - self.token_embedding = embedding_layer.token_embedding def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/roberta/roberta_classifier.py b/keras_nlp/models/roberta/roberta_classifier.py index 9098d95429..887bc657d4 100644 --- a/keras_nlp/models/roberta/roberta_classifier.py +++ b/keras_nlp/models/roberta/roberta_classifier.py @@ -144,38 +144,54 @@ def __init__( dropout=0.0, **kwargs, ): - inputs = backbone.input + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.pooled_dropout = keras.layers.Dropout( + dropout, + dtype=backbone.dtype_policy, + name="pooled_dropout", + ) hidden_dim = hidden_dim or backbone.hidden_dim - - x = backbone(inputs)[:, backbone.start_token_index, :] - x = keras.layers.Dropout(dropout, name="pooled_dropout")(x) - x = keras.layers.Dense( - hidden_dim, activation="tanh", name="pooled_dense" - )(x) - x = keras.layers.Dropout(dropout, name="classifier_dropout")(x) - outputs = keras.layers.Dense( + self.pooled_dense = keras.layers.Dense( + hidden_dim, + activation="tanh", + dtype=backbone.dtype_policy, + name="pooled_dense", + ) + self.output_dropout = keras.layers.Dropout( + dropout, + dtype=backbone.dtype_policy, + name="output_dropout", + ) + self.output_dense = keras.layers.Dense( num_classes, kernel_initializer=roberta_kernel_initializer(), activation=activation, + dtype=backbone.dtype_policy, name="logits", - )(x) + ) - # Instantiate using Functional API Model constructor + # === Functional Model === + inputs = backbone.input + x = backbone(inputs)[:, backbone.start_token_index, :] + x = self.pooled_dropout(x) + x = self.pooled_dense(x) + x = self.output_dropout(x) + 
outputs = self.output_dense(x) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor + + # === Config === self.num_classes = num_classes self.activation = keras.activations.get(activation) self.hidden_dim = hidden_dim self.dropout = dropout - # Default compilation + # === Default compilation === logit_output = self.activation == keras.activations.linear self.compile( loss=keras.losses.SparseCategoricalCrossentropy( diff --git a/keras_nlp/models/roberta/roberta_masked_lm.py b/keras_nlp/models/roberta/roberta_masked_lm.py index 1517f25914..bf96189860 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm.py +++ b/keras_nlp/models/roberta/roberta_masked_lm.py @@ -103,6 +103,19 @@ def __init__( preprocessor=None, **kwargs, ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.masked_lm_head = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + token_embedding=backbone.token_embedding, + intermediate_activation="gelu", + kernel_initializer=roberta_kernel_initializer(), + dtype=backbone.dtype_policy, + name="mlm_head", + ) + + # === Functional Model === inputs = { **backbone.input, "mask_positions": keras.Input( @@ -110,25 +123,16 @@ def __init__( ), } backbone_outputs = backbone(backbone.input) - outputs = MaskedLMHead( - vocabulary_size=backbone.vocabulary_size, - token_embedding=backbone.token_embedding, - intermediate_activation="gelu", - kernel_initializer=roberta_kernel_initializer(), - name="mlm_head", - )(backbone_outputs, inputs["mask_positions"]) - - # Instantiate using Functional API Model constructor + outputs = self.masked_lm_head( + backbone_outputs, inputs["mask_positions"] + ) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), diff --git a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py index ae762079e2..a842e99f5d 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py @@ -44,7 +44,7 @@ def setUp(self): self.input_data = [" airplane airport"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=RobertaMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/roberta/roberta_preprocessor.py b/keras_nlp/models/roberta/roberta_preprocessor.py index 556561d17c..57a421590f 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor.py +++ b/keras_nlp/models/roberta/roberta_preprocessor.py @@ -143,9 +143,9 @@ def __init__( super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.truncate = truncate self.sequence_length = sequence_length - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -180,6 +180,17 @@ def get_config(self): ) return config + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def 
sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return RobertaTokenizer diff --git a/keras_nlp/models/roberta/roberta_preprocessor_test.py b/keras_nlp/models/roberta/roberta_preprocessor_test.py index 5e7ad77514..699742ea08 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor_test.py +++ b/keras_nlp/models/roberta/roberta_preprocessor_test.py @@ -41,7 +41,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=RobertaPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index 6e76094d71..5fb383458f 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -13,7 +13,6 @@ # limitations under the License. import copy -from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone @@ -23,7 +22,7 @@ from keras_nlp.utils.python_utils import classproperty -@keras_nlp_export("keras_nlp.models.T5Backbone") +@keras.saving.register_keras_serializable(package="keras_nlp") class T5Backbone(Backbone): """T5 encoder-decoder backbone model. @@ -67,7 +66,11 @@ class T5Backbone(Backbone): layer normalization layers in the Transformer layers. tie_embedding_weights: boolean. If `True`, the weights of the token embedding and the weights projecting language model outputs from - `hidden_dim` + `hidden_dim`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. """ def __init__( @@ -83,47 +86,26 @@ def __init__( use_gated_activation=True, layer_norm_epsilon=1e-06, tie_embedding_weights=True, + dtype=None, **kwargs, ): - # Encoder inputs - encoder_token_ids = keras.Input( - shape=(None,), dtype="int32", name="encoder_token_ids" - ) - encoder_padding_mask = keras.Input( - shape=(None,), dtype="int32", name="encoder_padding_mask" - ) - - # Decoder inputs. - decoder_token_ids = keras.Input( - shape=(None,), dtype="int32", name="decoder_token_ids" - ) - decoder_padding_mask = keras.Input( - shape=(None,), dtype="int32", name="decoder_padding_mask" - ) - # Token embedding layer. This layer is shared by encoder and decoder. - token_embedding_layer = ReversibleEmbedding( + self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, tie_weights=tie_embedding_weights, embeddings_initializer=keras.initializers.TruncatedNormal(1.0), + dtype=dtype, name="token_embedding", ) - - # ===== Encoder ===== - - # Embed tokens. 
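The `sequence_length` property added to the Mistral, OPT, and RoBERTa preprocessors above keeps an already-built packer in sync when the attribute is reassigned, which matters because packer creation is deferred to `build()`. A stand-alone sketch of that setter pattern with dummy classes (the names here are illustrative, not the real layers):

```python
class DummyPacker:
    def __init__(self, sequence_length):
        self.sequence_length = sequence_length


class DummyPreprocessor:
    def __init__(self, sequence_length):
        self.packer = None
        self.sequence_length = sequence_length  # packer not built yet

    def build(self):
        # Packer creation is deferred, mirroring the preprocessors above.
        self.packer = DummyPacker(self.sequence_length)

    @property
    def sequence_length(self):
        return self._sequence_length

    @sequence_length.setter
    def sequence_length(self, value):
        self._sequence_length = value
        if self.packer is not None:
            self.packer.sequence_length = value


p = DummyPreprocessor(sequence_length=8)
p.build()
p.sequence_length = 16
print(p.packer.sequence_length)  # 16, the packer tracks the new value
```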
- token_embedding = token_embedding_layer(encoder_token_ids) - x = keras.layers.Dropout( + self.encoder_embedding_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="encoder_embedding_dropout", - )(token_embedding) - - encoder_attention_mask = encoder_padding_mask[:, None, :] - - position_bias = None + ) + self.encoder_transformer_layers = [] for i in range(num_layers): - output = T5TransformerLayer( + layer = T5TransformerLayer( is_decoder=False, hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, @@ -134,40 +116,28 @@ def __init__( num_heads=num_heads, use_gated_activation=use_gated_activation, use_relative_attention_bias=bool(i == 0), + dtype=dtype, name=f"transformer_encoder_layer_{i}", - )( - x, - attention_mask=encoder_attention_mask, - position_bias=position_bias, - use_causal_mask=False, ) - if isinstance(output, tuple): - x, position_bias = output - - x = T5LayerNorm( + self.encoder_transformer_layers.append(layer) + self.encoder_layer_norm = T5LayerNorm( epsilon=layer_norm_epsilon, + dtype=dtype, name="encoder_output_layer_norm", - )(x) - x = keras.layers.Dropout( + ) + self.encoder_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="encoder_output_dropout", - )(x) - encoder_output = x - - # ===== Decoder ===== - - # Embed tokens. - token_embedding = token_embedding_layer(decoder_token_ids) - x = keras.layers.Dropout( + ) + self.decoder_embedding_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="decoder_embedding_dropout", - )(token_embedding) - - decoder_attention_mask = decoder_padding_mask[:, None, :] - - position_bias = None + ) + self.decoder_transformer_layers = [] for i in range(num_layers): - output = T5TransformerLayer( + layer = T5TransformerLayer( is_decoder=True, hidden_dim=hidden_dim, intermediate_dim=intermediate_dim, @@ -178,8 +148,58 @@ def __init__( num_heads=num_heads, use_gated_activation=use_gated_activation, use_relative_attention_bias=bool(i == 0), + dtype=dtype, name=f"transformer_decoder_layer_{i}", - )( + ) + self.decoder_transformer_layers.append(layer) + self.decoder_layer_norm = T5LayerNorm( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="decoder_output_layer_norm", + ) + self.decoder_dropout = keras.layers.Dropout( + dropout, + dtype=dtype, + name="decoder_output_dropout", + ) + + # === Functional Model === + encoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_token_ids" + ) + encoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="encoder_padding_mask" + ) + decoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_token_ids" + ) + decoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_padding_mask" + ) + # Encoder. + x = self.token_embedding(encoder_token_id_input) + x = self.encoder_embedding_dropout(x) + encoder_attention_mask = encoder_padding_mask_input[:, None, :] + position_bias = None + for transformer_layer in self.encoder_transformer_layers: + output = transformer_layer( + x, + attention_mask=encoder_attention_mask, + position_bias=position_bias, + use_causal_mask=False, + ) + if isinstance(output, tuple): + x, position_bias = output + x = self.encoder_layer_norm(x) + x = self.encoder_dropout(x) + encoder_output = x + # Decoder. 
+ x = self.token_embedding(decoder_token_id_input) + x = self.decoder_embedding_dropout(x) + decoder_attention_mask = decoder_padding_mask_input[:, None, :] + position_bias = None + for transformer_layer in self.decoder_transformer_layers: + output = transformer_layer( x, attention_mask=decoder_attention_mask, position_bias=position_bias, @@ -189,23 +209,15 @@ def __init__( ) if isinstance(output, tuple): x, position_bias = output - - x = T5LayerNorm( - epsilon=layer_norm_epsilon, - name="decoder_output_layer_norm", - )(x) - x = keras.layers.Dropout( - dropout, - name="decoder_output_dropout", - )(x) + x = self.decoder_layer_norm(x) + x = self.decoder_dropout(x) decoder_output = x - super().__init__( { - "encoder_token_ids": encoder_token_ids, - "encoder_padding_mask": encoder_padding_mask, - "decoder_token_ids": decoder_token_ids, - "decoder_padding_mask": decoder_padding_mask, + "encoder_token_ids": encoder_token_id_input, + "encoder_padding_mask": encoder_padding_mask_input, + "decoder_token_ids": decoder_token_id_input, + "decoder_padding_mask": decoder_padding_mask_input, }, outputs={ "encoder_sequence_output": encoder_output, @@ -213,7 +225,8 @@ def __init__( }, **kwargs, ) - # All references to `self` below this line + + # === Config === self.vocabulary_size = vocabulary_size self.hidden_dim = hidden_dim self.intermediate_dim = intermediate_dim @@ -225,7 +238,6 @@ def __init__( self.use_gated_activation = use_gated_activation self.layer_norm_epsilon = layer_norm_epsilon self.tie_embedding_weights = tie_embedding_weights - self.token_embedding = token_embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/t5/t5_multi_head_attention.py b/keras_nlp/models/t5/t5_multi_head_attention.py index 77e7109efe..2e3647b1d5 100644 --- a/keras_nlp/models/t5/t5_multi_head_attention.py +++ b/keras_nlp/models/t5/t5_multi_head_attention.py @@ -45,36 +45,43 @@ def __init__( self.query_projector = keras.layers.Dense( self.inner_dim, use_bias=False, - name="query_projector", kernel_initializer=keras.initializers.RandomNormal( mean=0, stddev=(self.inner_dim * self.key_value_dim) ** -0.5 ), + dtype=self.dtype_policy, + name="query_projector", ) self.key_projector = keras.layers.Dense( self.inner_dim, use_bias=False, - name="key_projector", kernel_initializer=keras.initializers.RandomNormal( mean=0, stddev=self.inner_dim**-0.5 ), + dtype=self.dtype_policy, + name="key_projector", ) self.value_projector = keras.layers.Dense( self.inner_dim, use_bias=False, - name="value_projector", kernel_initializer=keras.initializers.RandomNormal( mean=0, stddev=self.inner_dim**-0.5 ), + dtype=self.dtype_policy, + name="value_projector", ) self.output_projector = keras.layers.Dense( self.hidden_dim, use_bias=False, - name="output_projector", kernel_initializer=keras.initializers.RandomNormal( mean=0, stddev=self.inner_dim**-0.5 ), + dtype=self.dtype_policy, + name="output_projector", + ) + self.dropout_layer = keras.layers.Dropout( + dropout, + dtype=self.dtype_policy, ) - self.dropout_layer = keras.layers.Dropout(dropout) if self.use_relative_attention_bias: self.relative_attention_bias = self.add_weight( @@ -298,7 +305,7 @@ def project( mask = (1.0 - ops.cast(mask, position_bias.dtype)) * -1e9 position_bias = position_bias + mask - scores += position_bias + scores += ops.cast(position_bias, scores.dtype) weights = ops.nn.softmax( scores, axis=-1 ) # (batch_size, num_heads, query_length, key_length) diff --git a/keras_nlp/models/t5/t5_tokenizer.py 
b/keras_nlp/models/t5/t5_tokenizer.py index b5dee49b85..5feb2d9ab8 100644 --- a/keras_nlp/models/t5/t5_tokenizer.py +++ b/keras_nlp/models/t5/t5_tokenizer.py @@ -13,13 +13,13 @@ # limitations under the License. import copy -from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras from keras_nlp.models.t5.t5_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer from keras_nlp.utils.python_utils import classproperty -@keras_nlp_export("keras_nlp.models.T5Tokenizer") +@keras.saving.register_keras_serializable(package="keras_nlp") class T5Tokenizer(SentencePieceTokenizer): """T5 tokenizer layer based on SentencePiece. diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py index 27b4c9892c..ddff7a164a 100644 --- a/keras_nlp/models/t5/t5_transformer_layer.py +++ b/keras_nlp/models/t5/t5_transformer_layer.py @@ -47,10 +47,17 @@ def __init__( num_heads=num_heads, dropout=dropout, use_relative_attention_bias=use_relative_attention_bias, + dtype=self.dtype_policy, name="self_attention", ) - self.self_attention_layer_norm = T5LayerNorm(layer_norm_epsilon) - self.self_attention_dropout = keras.layers.Dropout(dropout) + self.self_attention_layer_norm = T5LayerNorm( + layer_norm_epsilon, + dtype=self.dtype_policy, + ) + self.self_attention_dropout = keras.layers.Dropout( + dropout, + dtype=self.dtype_policy, + ) if self.is_decoder: self.cross_attention = T5MultiHeadAttention( @@ -60,39 +67,55 @@ def __init__( num_heads=num_heads, dropout=dropout, use_relative_attention_bias=False, + dtype=self.dtype_policy, name="cross_attention", ) - self.cross_attention_layer_norm = T5LayerNorm(layer_norm_epsilon) - self.cross_attention_dropout = keras.layers.Dropout(dropout) + self.cross_attention_layer_norm = T5LayerNorm( + layer_norm_epsilon, + dtype=self.dtype_policy, + ) + self.cross_attention_dropout = keras.layers.Dropout( + dropout, + dtype=self.dtype_policy, + ) self.input_projector = keras.layers.Dense( intermediate_dim, use_bias=False, - name="input_projector", activation=keras.activations.get(activation), kernel_initializer=keras.initializers.RandomNormal( mean=0, stddev=hidden_dim**-0.5 ), + dtype=self.dtype_policy, + name="input_projector", ) if self.use_gated_activation: self.gate_projector = keras.layers.Dense( intermediate_dim, use_bias=False, - name="gate_projector", kernel_initializer=keras.initializers.RandomNormal( mean=0, stddev=hidden_dim**-0.5 ), + dtype=self.dtype_policy, + name="gate_projector", ) self.output_projector = keras.layers.Dense( hidden_dim, use_bias=False, - name="output_projector", kernel_initializer=keras.initializers.RandomNormal( mean=0, stddev=intermediate_dim**-0.5 ), + dtype=self.dtype_policy, + name="output_projector", + ) + self.layer_norm = T5LayerNorm( + epsilon=layer_norm_epsilon, + dtype=self.dtype_policy, + ) + self.dropout_layer = keras.layers.Dropout( + dropout, + dtype=self.dtype_policy, ) - self.layer_norm = T5LayerNorm(epsilon=layer_norm_epsilon) - self.dropout_layer = keras.layers.Dropout(dropout) def call( self, @@ -108,8 +131,7 @@ def call( shape = ops.shape(hidden_states) batch_size, length = shape[0], shape[1] causal_mask = compute_causal_mask(batch_size, length, length) - attention_mask = ops.cast(attention_mask, "int32") - attention_mask = causal_mask & attention_mask + attention_mask = causal_mask & ops.cast(attention_mask, "bool") x = hidden_states # Intermediate result. 
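Note on the T5 hunks above: the backbone now builds its sub-layers as attributes (with the new `dtype` argument threaded into each one) and only afterwards wires them into the functional graph, and the attention layer casts `position_bias` to the score dtype so the model holds up under reduced precision. A minimal usage sketch follows; the hyperparameter values are illustrative only, and the class is imported from its module path because this diff removes the `keras_nlp_export` decorator from it.

```python
import numpy as np

from keras_nlp.models.t5.t5_backbone import T5Backbone

# Tiny, illustrative configuration (not taken from any preset).
backbone = T5Backbone(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
    dtype="mixed_float16",  # new argument, forwarded to every sub-layer
)
batch_size, seq_len = 2, 8
token_ids = np.ones((batch_size, seq_len), dtype="int32")
mask = np.ones((batch_size, seq_len), dtype="int32")
outputs = backbone(
    {
        "encoder_token_ids": token_ids,
        "encoder_padding_mask": mask,
        "decoder_token_ids": token_ids,
        "decoder_padding_mask": mask,
    }
)
# Both outputs keep `hidden_dim` as the trailing dimension.
print(outputs["encoder_sequence_output"].shape)  # (2, 8, 32)
print(outputs["decoder_sequence_output"].shape)  # (2, 8, 32)
```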
diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index ee28e3a984..783cc0b41b 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -31,18 +31,25 @@ class Task(PipelineModel): """Base class for Task models.""" def __init__(self, *args, **kwargs): - self._backbone = None - self._preprocessor = None super().__init__(*args, **kwargs) self._functional_layer_ids = set( id(layer) for layer in self._flatten_layers() ) + self._initialized = True def __dir__(self): - # Temporary fixes for weight saving. This mimics the following PR for + if config.keras_3(): + return super().__dir__() + + # Temporary fixes for Keras 2 saving. This mimics the following PR for # older version of Keras: https://github.com/keras-team/keras/pull/18982 def filter_fn(attr): - if attr == "_layer_checkpoint_dependencies": + if attr in [ + "_layer_checkpoint_dependencies", + "transformer_layers", + "encoder_transformer_layers", + "decoder_transformer_layers", + ]: return False return id(getattr(self, attr)) not in self._functional_layer_ids @@ -99,17 +106,28 @@ def compile(self, optimizer="rmsprop", loss=None, **kwargs): super().compile(optimizer=optimizer, loss=loss, **kwargs) def preprocess_samples(self, x, y=None, sample_weight=None): - return self.preprocessor(x, y=y, sample_weight=sample_weight) + if self.preprocessor is not None: + return self.preprocessor(x, y=y, sample_weight=sample_weight) + else: + return super().preprocess_samples(x, y, sample_weight) def __setattr__(self, name, value): - # Work around torch setattr for properties. - if name in ["backbone", "preprocessor"]: + # Work around setattr issues for Keras 2 and Keras 3 torch backend. + # Since all our state is covered by functional model we can route + # around custom setattr calls. + is_property = isinstance(getattr(type(self), name, None), property) + is_unitialized = not hasattr(self, "_initialized") + is_torch = config.backend() == "torch" + is_keras_2 = not config.keras_3() + if is_torch and (is_property or is_unitialized): + return object.__setattr__(self, name, value) + if is_keras_2 and is_unitialized: return object.__setattr__(self, name, value) return super().__setattr__(name, value) @property def backbone(self): - """A `keras.Model` instance providing the backbone submodel.""" + """A `keras.Model` instance providing the backbone sub-model.""" return self._backbone @backbone.setter @@ -123,7 +141,6 @@ def preprocessor(self): @preprocessor.setter def preprocessor(self, value): - self.include_preprocessing = value is not None self._preprocessor = value def get_config(self): @@ -203,9 +220,14 @@ def from_preset( # Backbone case. if preset_cls == cls.backbone_cls: + # Forward dtype to the backbone. 
+ config_overrides = {} + if "dtype" in kwargs: + config_overrides["dtype"] = kwargs.pop("dtype") backbone = load_from_preset( preset, load_weights=load_weights, + config_overrides=config_overrides, ) if "preprocessor" in kwargs: preprocessor = kwargs.pop("preprocessor") diff --git a/keras_nlp/models/task_test.py b/keras_nlp/models/task_test.py index 09fe1b0086..bf82e4fa68 100644 --- a/keras_nlp/models/task_test.py +++ b/keras_nlp/models/task_test.py @@ -32,11 +32,11 @@ def __init__(self, **kwargs): class SimpleTask(Task): def __init__(self, preprocessor=None, activation=None, **kwargs): + self.preprocessor = preprocessor + self.activation = keras.activations.get(activation) inputs = keras.Input(shape=(5,)) outputs = keras.layers.Dense(5)(inputs) super().__init__(inputs, outputs, **kwargs) - self.preprocessor = preprocessor - self.activation = keras.activations.get(activation) class TestTask(TestCase): diff --git a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py index e41519bbc9..5fade1d63b 100644 --- a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py +++ b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py @@ -17,7 +17,7 @@ import numpy as np import tensorflow as tf -from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) @@ -26,7 +26,7 @@ from keras_nlp.utils.python_utils import format_docstring -@keras_nlp_export("keras_nlp.models.WhisperAudioFeatureExtractor") +@keras.saving.register_keras_serializable(package="keras_nlp") class WhisperAudioFeatureExtractor(PreprocessingLayer): """ Whisper audio feature extractor layer. diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py index 32cfab215b..6da83c4b0d 100644 --- a/keras_nlp/models/whisper/whisper_backbone.py +++ b/keras_nlp/models/whisper/whisper_backbone.py @@ -14,7 +14,6 @@ import copy -from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.layers.modeling.position_embedding import PositionEmbedding @@ -38,7 +37,7 @@ def call(self, x): return ops.pad(x, [[0, 0], [1, 1], [0, 0]]) -@keras_nlp_export("keras_nlp.models.WhisperBackbone") +@keras.saving.register_keras_serializable(package="keras_nlp") class WhisperBackbone(Backbone): """A Whisper encoder-decoder network for speech. @@ -75,6 +74,10 @@ class WhisperBackbone(Backbone): positional embedding layer. max_decoder_sequence_length: int. The maximum sequence length that the text decoder can consume. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Examples: @@ -112,79 +115,51 @@ def __init__( dropout=0.0, max_encoder_sequence_length=3000, max_decoder_sequence_length=448, + dtype=None, **kwargs, ): assert_tf_backend(self.__class__.__name__) - # Encoder inputs. Note that the encoder does not have a padding mask: - # https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L132. - encoder_feature_input = keras.Input( - shape=(None, num_mels), dtype="float32", name="encoder_features" - ) - - # Decoder inputs. 
- decoder_token_id_input = keras.Input( - shape=(None,), dtype="int32", name="decoder_token_ids" - ) - decoder_padding_mask = keras.Input( - shape=(None,), dtype="int32", name="decoder_padding_mask" - ) - - # ====== Encoder ====== - - # Embed the input features. This consists of two 1D convolutional - # layers. - # For the first layer, we use `padding="same"` since that corresponds to - # a padding size of 1. - encoder_conv_layer_1 = keras.layers.Conv1D( + # === Layers === + self.encoder_conv_layer_1 = keras.layers.Conv1D( filters=hidden_dim, kernel_size=3, strides=1, padding="same", + dtype=dtype, name="encoder_token_embedding_conv_layer_1", ) - embedded_features = keras.activations.gelu( - encoder_conv_layer_1(encoder_feature_input), - approximate=False, - ) - - # For the second conv. layer, we cannot use `padding="same"` since - # that corresponds to a padding size of 1.5 (since stride is 2). Hence, - # we will manually pad the input. - embedded_features = Padder()(embedded_features) - encoder_conv_layer_2 = keras.layers.Conv1D( + self.encoder_conv_layer_2 = keras.layers.Conv1D( filters=hidden_dim, kernel_size=3, strides=2, padding="valid", + dtype=dtype, name="encoder_token_embedding_conv_layer_2", ) - embedded_features = keras.activations.gelu( - encoder_conv_layer_2(embedded_features), - approximate=False, + self.encoder_padder = Padder( + dtype=dtype, + name="encoder_padder", ) - - # The position embedding layer for the encoder is a sinusoidal embedding - # layer: https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L137. - # Hence, we set it to be non-trainable. - # TODO: We can use `keras_nlp.layers.SinePositionEncoding` layer. - position_embedding = PositionEmbedding( + self.encoder_position_embedding = PositionEmbedding( initializer=whisper_kernel_initializer(), sequence_length=max_encoder_sequence_length // 2, + dtype=dtype, name="encoder_position_embedding", trainable=False, - )(embedded_features) - - # Sum and apply dropout to embeddings. - x = keras.layers.Add()((embedded_features, position_embedding)) - x = keras.layers.Dropout( + ) + self.encoder_embeddings_add = keras.layers.Add( + dtype=dtype, + name="encoder_embeddings_add", + ) + self.encoder_embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="encoder_embeddings_dropout", - )(x) - - # Apply successive transformer encoder blocks. + ) + self.encoder_transformer_layers = [] for i in range(num_layers): - x = WhisperEncoder( + layer = WhisperEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, activation=keras.activations.gelu, @@ -192,38 +167,33 @@ def __init__( dropout=dropout, kernel_initializer=whisper_kernel_initializer(), normalize_first=True, + dtype=dtype, name=f"transformer_encoder_layer_{i}", - )(x) - - x = keras.layers.LayerNormalization( - name="encoder_layer_norm", + ) + self.encoder_transformer_layers.append(layer) + self.encoder_layer_norm = keras.layers.LayerNormalization( axis=-1, epsilon=1e-5, - dtype="float32", - )(x) - encoder_output = x - - # ====== Decoder ====== - - # Embed tokens and positions. - embedding_layer = TokenAndPositionEmbedding( + dtype=dtype, + name="encoder_layer_norm", + ) + self.decoder_embeddings = TokenAndPositionEmbedding( vocabulary_size=vocabulary_size, sequence_length=max_decoder_sequence_length, embedding_dim=hidden_dim, embeddings_initializer=whisper_kernel_initializer(), + dtype=dtype, name="decoder_token_and_position_embedding", ) - x = embedding_layer(decoder_token_id_input) - - # Apply dropout to embeddings. 
- x = keras.layers.Dropout( + self.token_embedding = self.decoder_embeddings.token_embedding + self.decoder_embeddings_dropout = keras.layers.Dropout( dropout, + dtype=dtype, name="decoder_embeddings_dropout", - )(x) - - # Apply successive transformer decoder blocks. + ) + self.decoder_transformer_layers = [] for i in range(num_layers): - transformer_decoder_layer = WhisperDecoder( + layer = WhisperDecoder( intermediate_dim=intermediate_dim, num_heads=num_heads, dropout=dropout, @@ -231,28 +201,73 @@ def __init__( layer_norm_epsilon=1e-5, kernel_initializer=whisper_kernel_initializer(), normalize_first=True, + dtype=dtype, name=f"transformer_decoder_layer_{i}", ) - x = transformer_decoder_layer( + self.decoder_transformer_layers.append(layer) + self.decoder_layer_norm = keras.layers.LayerNormalization( + axis=-1, + epsilon=1e-5, + dtype=dtype, + name="decoder_layer_norm", + ) + + # === Functional Model === + # Note that the encoder does not have a padding mask: + # https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L132. + encoder_feature_input = keras.Input( + shape=(None, num_mels), dtype="float32", name="encoder_features" + ) + decoder_token_id_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_token_ids" + ) + decoder_padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="decoder_padding_mask" + ) + # Encoder. + # Embed the input features. This consists of two 1D convolutional + # layers. + # For the first layer, we use `padding="same"` since that corresponds to + # a padding size of 1. + embedded_features = keras.activations.gelu( + self.encoder_conv_layer_1(encoder_feature_input), + approximate=False, + ) + # For the second conv. layer, we cannot use `padding="same"` since + # that corresponds to a padding size of 1.5 (since stride is 2). Hence, + # we will manually pad the input. + embedded_features = self.encoder_padder(embedded_features) + embedded_features = keras.activations.gelu( + self.encoder_conv_layer_2(embedded_features), + approximate=False, + ) + # The position embedding layer for the encoder is a sinusoidal embedding + # layer: https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L137. + # Hence, we set it to be non-trainable. + # TODO: We can use `keras_nlp.layers.SinePositionEncoding` layer. + positions = self.encoder_position_embedding(embedded_features) + x = self.encoder_embeddings_add((embedded_features, positions)) + x = self.encoder_embeddings_dropout(x) + for transformer_layer in self.encoder_transformer_layers: + x = transformer_layer(x) + x = self.encoder_layer_norm(x) + encoder_output = x + # Decoder. 
+ x = self.decoder_embeddings(decoder_token_id_input) + x = self.decoder_embeddings_dropout(x) + for transformer_layer in self.decoder_transformer_layers: + x = transformer_layer( decoder_sequence=x, encoder_sequence=encoder_output, - decoder_padding_mask=decoder_padding_mask, + decoder_padding_mask=decoder_padding_mask_input, ) - - x = keras.layers.LayerNormalization( - name="decoder_layer_norm", - axis=-1, - epsilon=1e-5, - dtype="float32", - )(x) + x = self.decoder_layer_norm(x) decoder_output = x - - # Instantiate using Functional API Model constructor super().__init__( inputs={ "encoder_features": encoder_feature_input, "decoder_token_ids": decoder_token_id_input, - "decoder_padding_mask": decoder_padding_mask, + "decoder_padding_mask": decoder_padding_mask_input, }, outputs={ "encoder_sequence_output": encoder_output, @@ -261,7 +276,7 @@ def __init__( **kwargs, ) - # All references to `self` below this line + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads @@ -271,7 +286,6 @@ def __init__( self.dropout = dropout self.max_encoder_sequence_length = max_encoder_sequence_length self.max_decoder_sequence_length = max_decoder_sequence_length - self.token_embedding = embedding_layer def get_config(self): config = super().get_config() diff --git a/keras_nlp/models/whisper/whisper_preprocessor.py b/keras_nlp/models/whisper/whisper_preprocessor.py index abcff0d770..5ddec8732b 100644 --- a/keras_nlp/models/whisper/whisper_preprocessor.py +++ b/keras_nlp/models/whisper/whisper_preprocessor.py @@ -16,7 +16,6 @@ from absl import logging -from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker from keras_nlp.models.preprocessor import Preprocessor @@ -32,7 +31,7 @@ from keras_nlp.utils.python_utils import classproperty -@keras_nlp_export("keras_nlp.models.WhisperPreprocessor") +@keras.saving.register_keras_serializable(package="keras_nlp") class WhisperPreprocessor(Preprocessor): """A Whisper preprocessing layer which handles audio and text input. 
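The `task.py` and `whisper_backbone.py` hunks above apply the same pattern: explicit `# === Layers ===`, `# === Functional Model ===`, and `# === Config ===` sections, plus a `dtype` keyword that `from_preset` now forwards to the backbone through `config_overrides`. A short sketch of the user-facing effect, reusing the GPT-2 preset that appears in the sampler docstrings later in this diff and assuming its backbone accepts the same `dtype` override added to the backbones shown here:

```python
import keras_nlp

# `dtype` is popped from the kwargs and forwarded to the backbone config,
# so the restored backbone is built under the requested precision policy.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en",
    dtype="bfloat16",
)
causal_lm.generate(["Keras is a"])
```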
@@ -169,11 +168,11 @@ def __init__( audio_feature_extractor = WhisperAudioFeatureExtractor() self.audio_feature_extractor = audio_feature_extractor self.tokenizer = tokenizer + self.decoder_packer = None self.decoder_sequence_length = decoder_sequence_length self.language = language self.task = task self.no_timestamps = no_timestamps - self.decoder_packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -307,6 +306,26 @@ def from_config(cls, config): return cls(**config) + @property + def decoder_sequence_length(self): + """The padded length of decoder input sequences.""" + return self._decoder_sequence_length + + @decoder_sequence_length.setter + def decoder_sequence_length(self, value): + self._decoder_sequence_length = value + if self.decoder_packer is not None: + self.decoder_packer.sequence_length = value + + @property + def sequence_length(self): + """Alias for `decoder_sequence_length`.""" + return self.decoder_sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self.decoder_sequence_length = value + @classproperty def audio_feature_extractor_cls(cls): return WhisperAudioFeatureExtractor diff --git a/keras_nlp/models/whisper/whisper_preprocessor_test.py b/keras_nlp/models/whisper/whisper_preprocessor_test.py index 6837dc8bfa..8517a6c102 100644 --- a/keras_nlp/models/whisper/whisper_preprocessor_test.py +++ b/keras_nlp/models/whisper/whisper_preprocessor_test.py @@ -66,10 +66,11 @@ def setUp(self): } def test_feature_extractor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=WhisperPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, + token_id_key="decoder_token_ids", ) def test_sequence_length_override(self): diff --git a/keras_nlp/models/whisper/whisper_tokenizer.py b/keras_nlp/models/whisper/whisper_tokenizer.py index 7b68dfd790..4446193738 100644 --- a/keras_nlp/models/whisper/whisper_tokenizer.py +++ b/keras_nlp/models/whisper/whisper_tokenizer.py @@ -15,7 +15,7 @@ import copy import json -from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras from keras_nlp.models.whisper.whisper_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer from keras_nlp.utils.python_utils import classproperty @@ -28,7 +28,7 @@ def _load_dict(dict_or_path): return dict_or_path -@keras_nlp_export("keras_nlp.models.WhisperTokenizer") +@keras.saving.register_keras_serializable(package="keras_nlp") class WhisperTokenizer(BytePairTokenizer): """Whisper text tokenizer using Byte-Pair Encoding subword segmentation. diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py index b83b4596b2..c74a0fd6fc 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py @@ -52,6 +52,10 @@ class XLMRobertaBackbone(roberta_backbone.RobertaBackbone): consume. The sequence length of the input must be less than `max_sequence_length` default value. This determines the variable shape for positional embeddings. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. 
Examples: ```python diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py index 67a9dd5bef..fcd8bfe9b8 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py @@ -157,30 +157,49 @@ def __init__( dropout=0.0, **kwargs, ): - inputs = backbone.input + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.pooled_dropout = keras.layers.Dropout( + dropout, + dtype=backbone.dtype_policy, + name="pooled_dropout", + ) hidden_dim = hidden_dim or backbone.hidden_dim - - x = backbone(inputs)[:, backbone.start_token_index, :] - x = keras.layers.Dropout(dropout, name="pooled_dropout")(x) - x = keras.layers.Dense( - hidden_dim, activation="tanh", name="pooled_dense" - )(x) - x = keras.layers.Dropout(dropout, name="classifier_dropout")(x) - outputs = keras.layers.Dense( + self.pooled_dense = keras.layers.Dense( + hidden_dim, + activation="tanh", + dtype=backbone.dtype_policy, + name="pooled_dense", + ) + self.output_dropout = keras.layers.Dropout( + dropout, + dtype=backbone.dtype_policy, + name="output_dropout", + ) + self.output_dense = keras.layers.Dense( num_classes, kernel_initializer=roberta_kernel_initializer(), activation=activation, + dtype=backbone.dtype_policy, name="logits", - )(x) + ) + # === Functional Model === + inputs = backbone.input + x = backbone(inputs)[:, backbone.start_token_index, :] + x = self.pooled_dropout(x) + x = self.pooled_dense(x) + x = self.output_dropout(x) + outputs = self.output_dense(x) # Instantiate using Functional API Model constructor super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line + + # === Config === self.backbone = backbone self.preprocessor = preprocessor self.num_classes = num_classes @@ -188,6 +207,7 @@ def __init__( self.hidden_dim = hidden_dim self.dropout = dropout + # === Default compilation === logit_output = self.activation == keras.activations.linear self.compile( loss=keras.losses.SparseCategoricalCrossentropy( @@ -198,9 +218,6 @@ def __init__( jit_compile=True, ) - def preprocess_samples(self, x, y=None, sample_weight=None): - return self.preprocessor(x, y=y, sample_weight=sample_weight) - def get_config(self): config = super().get_config() config.update( diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py index f0dfc85e84..e231f3dc7a 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py @@ -106,6 +106,19 @@ def __init__( preprocessor=None, **kwargs, ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.masked_lm_head = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + token_embedding=backbone.token_embedding, + intermediate_activation="gelu", + kernel_initializer=roberta_kernel_initializer(), + dtype=backbone.dtype_policy, + name="mlm_head", + ) + + # === Functional Model === inputs = { **backbone.input, "mask_positions": keras.Input( @@ -113,25 +126,16 @@ def __init__( ), } backbone_outputs = backbone(backbone.input) - outputs = MaskedLMHead( - vocabulary_size=backbone.vocabulary_size, - token_embedding=backbone.token_embedding, - intermediate_activation="gelu", - kernel_initializer=roberta_kernel_initializer(), - name="mlm_head", - )(backbone_outputs, inputs["mask_positions"]) - - 
# Instantiate using Functional API Model constructor. + outputs = self.masked_lm_head( + backbone_outputs, inputs["mask_positions"] + ) super().__init__( inputs=inputs, outputs=outputs, - include_preprocessing=preprocessor is not None, **kwargs, ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor + # === Default compilation === self.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py index c1bfc7242a..6d77e71319 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py @@ -45,7 +45,7 @@ def setUp(self): self.input_data = ["the quick brown fox"] def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=XLMRobertaMaskedLMPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py index 23b48073f7..c94f5f2421 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py @@ -156,9 +156,9 @@ def __init__( super().__init__(**kwargs) self.tokenizer = tokenizer + self.packer = None self.truncate = truncate self.sequence_length = sequence_length - self.packer = None def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -193,6 +193,17 @@ def call(self, x, y=None, sample_weight=None): } return pack_x_y_sample_weight(x, y, sample_weight) + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + @classproperty def tokenizer_cls(cls): return XLMRobertaTokenizer diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py index 3c3bbf2612..85c76fa282 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py @@ -44,7 +44,7 @@ def setUp(self): ) def test_preprocessor_basics(self): - self.run_preprocessing_layer_test( + self.run_preprocessor_test( cls=XLMRobertaPreprocessor, init_kwargs=self.init_kwargs, input_data=self.input_data, diff --git a/keras_nlp/models/xlnet/relative_attention.py b/keras_nlp/models/xlnet/relative_attention.py index a11ae3fd9d..6ae56d1450 100644 --- a/keras_nlp/models/xlnet/relative_attention.py +++ b/keras_nlp/models/xlnet/relative_attention.py @@ -154,6 +154,7 @@ def build(self, content_stream_shape): output_rank - 1, [self._num_heads, self._key_dim] ), bias_axes=bias_axes if self._use_bias else None, + dtype=self.dtype_policy, name="query", **self._get_common_kwargs_for_sublayer(), ) @@ -168,6 +169,7 @@ def build(self, content_stream_shape): output_rank - 1, [self._num_heads, self._key_dim] ), bias_axes=bias_axes if self._use_bias else None, + dtype=self.dtype_policy, name="key", **self._get_common_kwargs_for_sublayer(), ) @@ -182,6 +184,7 @@ def build(self, content_stream_shape): output_rank - 1, 
[self._num_heads, self._value_dim] ), bias_axes=bias_axes if self._use_bias else None, + dtype=self.dtype_policy, name="value", **self._get_common_kwargs_for_sublayer(), ) @@ -197,6 +200,7 @@ def build(self, content_stream_shape): output_rank - 1, [self._query_shape[-1]] ), bias_axes=None, + dtype=self.dtype_policy, name="attention_output", **self._get_common_kwargs_for_sublayer(), ) @@ -213,6 +217,7 @@ def build(self, content_stream_shape): output_rank - 1, [self._num_heads, self._key_dim] ), bias_axes=None, + dtype=self.dtype_policy, name="encoding", **self._get_common_kwargs_for_sublayer(), ) diff --git a/keras_nlp/models/xlnet/xlnet_backbone.py b/keras_nlp/models/xlnet/xlnet_backbone.py index 1d1b4d2343..1fe6086436 100644 --- a/keras_nlp/models/xlnet/xlnet_backbone.py +++ b/keras_nlp/models/xlnet/xlnet_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.models.backbone import Backbone from keras_nlp.models.xlnet.xlnet_content_and_query_embedding import ( @@ -23,7 +22,7 @@ from keras_nlp.models.xlnet.xlnet_encoder import XLNetSegmentMatrixLayer -@keras_nlp_export("keras_nlp.models.XLNetBackbone") +@keras.saving.register_keras_serializable(package="keras_nlp") class XLNetBackbone(Backbone): """XLNet encoder network. @@ -52,6 +51,10 @@ class XLNetBackbone(Backbone): bias_initializer: string or `keras.initializers` initializer, defaults to "zeros". The bias initializer for the dense and multiheaded relative attention layers. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. Call arguments: token_ids: Indices of input sequence tokens in the vocabulary of shape @@ -101,44 +104,31 @@ def __init__( activation="gelu", kernel_initializer_range=0.02, bias_initializer="zeros", + dtype=None, **kwargs, ): - # Inputs - token_id_input = keras.Input( - shape=(None,), dtype="int32", name="token_ids" - ) - padding_mask = keras.Input( - shape=(None,), dtype="int32", name="padding_mask" - ) - segment_ids = keras.Input( - shape=(None,), dtype="int32", name="segment_ids" - ) - - # Content and Query Embedding - word_emb, pos_emb = ContentAndQueryEmbedding( + # === Layers === + self.content_query_embedding = ContentAndQueryEmbedding( vocabulary_size=vocabulary_size, hidden_dim=hidden_dim, dropout=dropout, + dtype=dtype, name="content_query_embedding", - )(token_id_input=token_id_input) - - # Apply XLNetAttentionMaskLayer and XLNetSegmentMatrixLayer Layers - # to get the processed attention masks and segment matrix. 
- attn_mask_content, attn_mask_query = XLNetAttentionMaskLayer( + ) + self.attn_mask_layer = XLNetAttentionMaskLayer( hidden_dim=hidden_dim, kernel_initializer_range=kernel_initializer_range, + dtype=dtype, name="encoder_block_attn_mask_layer", - )(padding_mask) - seg_mat = XLNetSegmentMatrixLayer(name="encoder_block_seg_mat_layer")( - segment_ids ) - - output_content = word_emb - - # Encoders + self.seg_mat_layer = XLNetSegmentMatrixLayer( + dtype=dtype, + name="encoder_block_seg_mat_layer", + ) head_dim = hidden_dim // num_heads + self.transformer_layers = [] for i in range(num_layers): - output_content, output_query = XLNetEncoder( + layer = XLNetEncoder( num_heads=num_heads, hidden_dim=hidden_dim, head_dim=head_dim, @@ -148,28 +138,55 @@ def __init__( layer_norm_epsilon=1e-12, kernel_initializer_range=kernel_initializer_range, bias_initializer=bias_initializer, + dtype=dtype, name=f"xlnet_encoder_{i}", - )( + ) + self.transformer_layers.append(layer) + self.dropout = keras.layers.Dropout( + dropout, + dtype=dtype, + name="dropout", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + segment_id_input = keras.Input( + shape=(None,), dtype="int32", name="segment_ids" + ) + # Content and Query Embedding + word_emb, pos_emb = self.content_query_embedding(token_id_input) + # Apply XLNetAttentionMaskLayer and XLNetSegmentMatrixLayer Layers + # to get the processed attention masks and segment matrix. + attn_mask_content, attn_mask_query = self.attn_mask_layer( + padding_mask_input + ) + seg_mat = self.seg_mat_layer(segment_id_input) + output_content = word_emb + for transformer_layer in self.transformer_layers: + output_content, output_query = transformer_layer( output_content=output_content, attn_mask_content=attn_mask_content, attn_mask_query=attn_mask_query, pos_emb=pos_emb, seg_mat=seg_mat, ) - - output = keras.layers.Dropout(dropout)(output_content) - + output = self.dropout(output_content) super().__init__( inputs={ "token_ids": token_id_input, - "padding_mask": padding_mask, - "segment_ids": segment_ids, + "padding_mask": padding_mask_input, + "segment_ids": segment_id_input, }, outputs=output, **kwargs, ) - # All references to `self` below this line + # === Config === self.vocabulary_size = vocabulary_size self.num_layers = num_layers self.num_heads = num_heads diff --git a/keras_nlp/models/xlnet/xlnet_content_and_query_embedding.py b/keras_nlp/models/xlnet/xlnet_content_and_query_embedding.py index a2bc2d0d07..2de2c31780 100644 --- a/keras_nlp/models/xlnet/xlnet_content_and_query_embedding.py +++ b/keras_nlp/models/xlnet/xlnet_content_and_query_embedding.py @@ -58,7 +58,8 @@ def positional_embedding(self, pos_seq, inv_freq, bsz=None): ops.shape(pos_emb)[0], ops.shape(pos_emb)[1] * bsz, ops.shape(pos_emb)[2], - ] + ], + dtype=self.compute_dtype, ) * pos_emb ) @@ -67,12 +68,14 @@ def positional_embedding(self, pos_seq, inv_freq, bsz=None): def relative_positional_encoding(self, qlen, klen, bsz=None, clamp_len=-1): """create relative positional encoding.""" - freq_seq = ops.arange(0, self.hidden_dim, 2.0, dtype=self.compute_dtype) + freq_seq = ops.arange(0, self.hidden_dim, 2.0, dtype="float32") + freq_seq = ops.cast(freq_seq, self.compute_dtype) inv_freq = 1 / (10000 ** (freq_seq / self.hidden_dim)) beg, end = klen, -qlen - fwd_pos_seq = ops.arange(beg, end, -1.0, dtype=self.compute_dtype) + fwd_pos_seq = ops.arange(beg, end, 
-1.0, dtype="float32") + fwd_pos_seq = ops.cast(fwd_pos_seq, self.compute_dtype) if clamp_len > 0: fwd_pos_seq = ops.clip( fwd_pos_seq, x_min=-clamp_len, x_max=clamp_len @@ -85,11 +88,14 @@ def build(self, input_shape): self.word_embed = keras.layers.Embedding( input_dim=self.vocabulary_size, output_dim=self.hidden_dim, + dtype=self.dtype_policy, name="word_embedding", ) self.word_embed.build(input_shape) - self.dropout_layer = keras.layers.Dropout(self.dropout) - + self.dropout_layer = keras.layers.Dropout( + self.dropout, + dtype=self.dtype_policy, + ) super().build(input_shape) def call( diff --git a/keras_nlp/models/xlnet/xlnet_encoder.py b/keras_nlp/models/xlnet/xlnet_encoder.py index bb8e56e4cc..6aa44b150e 100644 --- a/keras_nlp/models/xlnet/xlnet_encoder.py +++ b/keras_nlp/models/xlnet/xlnet_encoder.py @@ -94,25 +94,34 @@ def build(self, input_shape): key_dim=self.head_dim, kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, name="rel_attn", ) self.layer_norm = keras.layers.LayerNormalization( - epsilon=self.layer_norm_epsilon, name="layer_norm_rel_attn" + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="layer_norm_rel_attn", ) self.layer_norm.build(input_shape) - self.dropout_attn = keras.layers.Dropout(self.dropout) + self.dropout_attn = keras.layers.Dropout( + self.dropout, + dtype=self.dtype_policy, + ) # Feed-Forward Part self.layer_norm_ff = keras.layers.LayerNormalization( - epsilon=self.layer_norm_epsilon, name="layer_norm_ff" + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="layer_norm_ff", ) self.layer_norm_ff.build(input_shape) self.feedforward_intermediate_dense = keras.layers.Dense( self.intermediate_dim, kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, name="feedforward_intermediate_dense", ) self.feedforward_intermediate_dense.build(input_shape) @@ -120,6 +129,7 @@ def build(self, input_shape): self.feedforward_output_dense = keras.layers.Dense( self.hidden_dim, kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, name="feedforward_output_dense", ) self.feedforward_output_dense.build( @@ -128,7 +138,10 @@ def build(self, input_shape): ) ) - self.dropout_ff = keras.layers.Dropout(self.dropout) + self.dropout_ff = keras.layers.Dropout( + self.dropout, + dtype=self.dtype_policy, + ) self.activation_function_ff = keras.activations.get(self.activation) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 9562f95d14..297ec203de 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -18,11 +18,8 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.BeamSampler") class BeamSampler(Sampler): """Beam Sampler class. @@ -42,55 +39,17 @@ class BeamSampler(Sampler): {{call_args}} Examples: - Return only the beam with the highest accumulated probability. ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. 
- int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((prompt_batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.BeamSampler()( - next=next, - prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzeeeeeee'] - ``` + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - Return all beams and their probabilities. - ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 8, len(int_lookup) - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache - - beams, probs = keras_nlp.samplers.BeamSampler(return_all_beams=True)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - - print(beams.shape) - # >>> (1, 5, 8) - print(probs.shape) - # >>> (1, 5) - print(["".join([int_lookup[i] for i in s]) for s in beams[0].numpy()]) - # >>> ['zzzzzeee', 'zzzzzeed', 'zzzzzeec', 'zzzzzeea', 'zzzzzeeb'] + # Pass by name to compile. + causal_lm.compile(sampler="beam") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. + sampler = keras_nlp.samplers.BeamSampler(num_beams=5) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ @@ -113,6 +72,7 @@ def __call__( mask=None, end_token_id=None, hidden_states=None, + model=None, ): batch_size, max_length = ops.shape(prompt)[0], ops.shape(prompt)[1] index = ops.cast(index, "int32") @@ -208,6 +168,7 @@ def gather_beams(x): body=body, loop_vars=(prompt, cache, index, log_probs), maximum_iterations=(max_length - index), + model=model, ) all_prompts = unflatten_beams(prompt) diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index bac65bcfbe..4259167c8c 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -17,11 +17,8 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.ContrastiveSampler") class ContrastiveSampler(Sampler): """Contrastive Sampler class. @@ -44,28 +41,16 @@ class ContrastiveSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters to [0, 26). 
- int_lookup = {i: chr(i + ord("a")) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - hidden_size = 5 - index = 5 - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, hidden_size)) - # A uniform distribution over our alphabet. - logits = np.ones((prompt_batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.ContrastiveSampler()( - next=next, - prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"), - index=index, - hidden_states=np.ones([batch_size, index, hidden_size]), - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> "zzzzzeeeeeee" + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Pass by name to compile. + causal_lm.compile(sampler="contrastive") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. + sampler = keras_nlp.samplers.ContrastiveSampler(k=5) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ @@ -88,6 +73,7 @@ def __call__( mask=None, end_token_id=None, hidden_states=None, + model=None, ): if hidden_states is None: raise ValueError( @@ -224,6 +210,7 @@ def gather_best_token(beams): body=body, loop_vars=(prompt, cache, index, logits, hidden_states), maximum_iterations=(max_length - index), + model=model, ) return prompt diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index 8e178b7468..ee8a6ecc2d 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -15,11 +15,8 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.GreedySampler") class GreedySampler(Sampler): """Greedy sampler class. @@ -27,29 +24,18 @@ class GreedySampler(Sampler): This sampler is implemented on greedy search, i.e., always picking up the token of the largest probability as the next token. - Call arguments: - {{call_args}} - Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.GreedySampler()( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzaaaaaaa'] + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Pass by name to compile. + causal_lm.compile(sampler="greedy") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. 
+ sampler = keras_nlp.samplers.GreedySampler() + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py index b922d29b2a..1ff39c9f9b 100644 --- a/keras_nlp/samplers/random_sampler.py +++ b/keras_nlp/samplers/random_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.RandomSampler") class RandomSampler(Sampler): """Random Sampler class. @@ -37,24 +34,16 @@ class RandomSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, state, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, state + # Pass by name to compile. + causal_lm.compile(sampler="random") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.RandomSampler()( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzcpnjqij'] + # Pass by object to compile. + sampler = keras_nlp.samplers.RandomSampler(temperature=0.7) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index e28fbe9d6e..3ecf16ac28 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -17,33 +17,8 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.backend import random -from keras_nlp.utils.python_utils import format_docstring - -call_args_docstring = """next: A function which takes in the - `prompt, cache, index` of the current generation loop, and outputs - a tuple `(logits, hidden_states, cache)` with `logits` being the - logits of next token, `hidden_states` being the representation of - the next token, and `cache` for next iteration. - prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This - tensor will be iteratively updated column by column with new sampled - values, starting at `index`. - cache: Optional. A tensor or nested structure of tensors that will be - updated by each call to `next`. This can be used to cache - computations from early iterations of the generative loop. - index: Optional. The first index of `prompt` to start sampling at. - Usually this is set as the length of the shortest non-padded - sequence in `prompt`. - mask: Optional. A 2D integer tensor with the same shape as `prompt`. - Locations which are `True` in the mask are never updated during - sampling. Usually used to mark all locations in the dense prompt - tensor which were present in a user input. - end_token_id: Optional. The token marking the end of the sequence. If - specified, sampling will stop as soon as all sequences in the prompt - produce a `end_token_id` in a location where `mask` is `False`. 
-""" - - -@format_docstring(call_args=call_args_docstring) + + @keras_nlp_export("keras_nlp.samplers.Sampler") class Sampler: """Base sampler class. @@ -57,35 +32,32 @@ class Sampler: {{call_args}} This base class can be extended to implement different auto-regressive - sampling methods. Subclasses can either: - - - Override the `get_next_token()` method, which computes the next token - based on a probability distribution over all possible vocab entries. - - Override `__call__`, if the sampling method needs additional information - beyond the next tokens probability distribution to sample a sequence. - - Please check available subclass samplers for examples. + sampling methods. To do so, override the `get_next_token()` method, which + computes the next token based on a probability distribution over all + possible vocab entries. Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - # return a uniform distribution over our alphabet. - logits = ops.ones((batch_size, vocab_size)) - return logits, None, cache - - output = keras_nlp.samplers.GreedySampler()( - next=next, - prompt=ops.fill((batch_size, length,), char_lookup['z']), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzaaaaaaa'] + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Greedy search with some tokens forbidden. + class CustomSampler(keras_nlp.samplers.Sampler): + def __init__(self, forbidden_tokens, **kwargs): + super().__init__(**kwargs) + self.forbidden_tokens = forbidden_tokens + + def get_next_token(self, probs): + batch_size, vocab_size = keras.ops.shape(probs) + for id in self.forbidden_tokens: + update = keras.ops.zeros((batch_size, 1)) + probs = keras.ops.slice_update(probs, (0, id), update) + return keras.ops.argmax(probs, axis=-1) + + # 257 = "a" with a leading space, 262 = "the" with a leading space. + causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262])) + causal_lm.summary() + causal_lm.generate(["That's strange"]) ``` """ @@ -120,6 +92,7 @@ def __call__( mask=None, end_token_id=None, hidden_states=None, + model=None, ): max_length = ops.shape(prompt)[-1] # Make sure `max_length` and `index` are the same dtype. 
@@ -161,6 +134,7 @@ def body(prompt, cache, index): body, loop_vars=(prompt, cache, index), maximum_iterations=(max_length - index), + model=model, ) return prompt @@ -175,32 +149,68 @@ def compute_probabilities(self, logits): probs = keras.activations.softmax(logits / self.temperature) return ops.cast(probs, logits_dtype) - def run_loop(self, cond, body, loop_vars=None, maximum_iterations=None): + def run_loop( + self, cond, body, model=None, loop_vars=None, maximum_iterations=None + ): """Run ops.while_loops with a `StatelessScope` if necessary.""" if config.backend() == "jax": + import itertools + + if model: + model_trainable_variables = model.trainable_variables + model_non_trainable_variables = model.non_trainable_variables + else: + model_trainable_variables = [] + model_non_trainable_variables = [] - def stateless_cond(variables, *loop_vars): + def stateless_cond(state, *loop_vars): return cond(*loop_vars) - def stateless_body(variables, *loop_vars): - mapping = zip(self.variables, variables) + def stateless_body(state, *loop_vars): + ( + sampler_variables, + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.variables, sampler_variables), + zip(model_trainable_variables, trainable_variables), + zip(model_non_trainable_variables, non_trainable_variables), + ) with keras.StatelessScope(state_mapping=mapping) as scope: loop_vars = body(*loop_vars) - variables = [] + sampler_variables = [] for v in self.variables: new_v = scope.get_current_value(v) - variables.append(new_v if new_v is not None else v) - return variables, *loop_vars + sampler_variables.append(new_v if new_v is not None else v) + state = ( + sampler_variables, + trainable_variables, + non_trainable_variables, + ) + return state, *loop_vars variables = [ops.convert_to_tensor(v) for v in self.variables] - variables, *loop_vars = ops.while_loop( + trainable_variables = [ + ops.convert_to_tensor(v) for v in model_trainable_variables + ] + non_trainable_variables = [ + ops.convert_to_tensor(v) for v in model_non_trainable_variables + ] + state = ( + variables, + trainable_variables, + non_trainable_variables, + ) + state, *loop_vars = ops.while_loop( cond=stateless_cond, body=stateless_body, - loop_vars=(variables, *loop_vars), + loop_vars=(state, *loop_vars), maximum_iterations=maximum_iterations, ) - [ref_v.assign(v) for ref_v, v in zip(self.variables, variables)] + for ref_v, v in zip(self.variables, state[0]): + ref_v.assign(v) else: loop_vars = ops.while_loop( cond=cond, diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 3456694848..513dd738c7 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.TopKSampler") class TopKSampler(Sampler): """Top-K Sampler class. @@ -38,24 +35,16 @@ class TopKSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. 
- int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache + # Pass by name to compile. + causal_lm.compile(sampler="top_k") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.TopKSampler(k=3)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtypes="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzacbbcaa'] + # Pass by object to compile. + sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index a04b39aa2b..326f5797a6 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.TopPSampler") class TopPSampler(Sampler): """Top-P Sampler class. @@ -46,24 +43,16 @@ class TopPSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache + # Pass by name to compile. + causal_lm.compile(sampler="top_p") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.TopPSampler(p=0.1)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzbabcccb'] + # Pass by object to compile. + sampler = keras_nlp.samplers.TopPSampler(p=0.1, k=1_000) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/tests/test_case.py b/keras_nlp/tests/test_case.py index 455a8569b7..0541ae6451 100644 --- a/keras_nlp/tests/test_case.py +++ b/keras_nlp/tests/test_case.py @@ -87,7 +87,7 @@ def run_layer_test( expected_num_non_trainable_weights=0, expected_num_non_trainable_variables=0, run_training_check=True, - run_mixed_precision_check=True, + run_precision_checks=True, ): """Run basic tests for a modeling layer.""" # Serialization test. @@ -142,6 +142,12 @@ def call(self, x): else: return self.layer(x) + input_data = tree.map_structure( + lambda x: ops.convert_to_numpy(x), input_data + ) + output_data = tree.map_structure( + lambda x: ops.convert_to_numpy(x), output_data + ) model = TestModel(layer) # Temporarily disable jit compilation on torch backend. 
jit_compile = config.backend() != "torch" @@ -181,24 +187,8 @@ def call(self, x): if run_training_check: run_training_step(layer, input_data, output_data) - # Never test mixed precision on torch CPU. Torch lacks support. - if run_mixed_precision_check and config.backend() == "torch": - import torch - - run_mixed_precision_check = torch.cuda.is_available() - - if run_mixed_precision_check: - layer = cls(**{**init_kwargs, "dtype": "mixed_float16"}) - if isinstance(input_data, dict): - output_data = layer(**input_data) - else: - output_data = layer(input_data) - for tensor in tree.flatten(output_data): - if is_float_dtype(tensor.dtype): - self.assertDTypeEqual(tensor, "float16") - for weight in layer.weights: - if is_float_dtype(weight.dtype): - self.assertDTypeEqual(weight, "float32") + if run_precision_checks: + self.run_precision_test(cls, init_kwargs, input_data) def run_preprocessing_layer_test( self, @@ -240,6 +230,42 @@ def run_preprocessing_layer_test( if expected_output: self.assertAllClose(output, expected_output) + def run_preprocessor_test( + self, + cls, + init_kwargs, + input_data, + expected_output=None, + expected_detokenize_output=None, + token_id_key="token_ids", + ): + """Run basic tests for a model preprocessor layer.""" + self.run_preprocessing_layer_test( + cls, + init_kwargs, + input_data, + expected_output=expected_output, + expected_detokenize_output=expected_detokenize_output, + ) + + layer = cls(**init_kwargs) + if isinstance(input_data, tuple): + output = layer(*input_data) + else: + output = layer(input_data) + output, _, _ = keras.utils.unpack_x_y_sample_weight(output) + shape = ops.shape(output[token_id_key]) + self.assertEqual(shape[-1], layer.sequence_length) + # Update the sequence length. + layer.sequence_length = 17 + if isinstance(input_data, tuple): + output = layer(*input_data) + else: + output = layer(input_data) + output, _, _ = keras.utils.unpack_x_y_sample_weight(output) + shape = ops.shape(output[token_id_key]) + self.assertEqual(shape[-1], 17) + def run_serialization_test(self, instance): """Check idempotency of serialize/deserialize. @@ -277,6 +303,40 @@ def run_serialization_test(self, instance): lst.remove("__annotations__") self.assertEqual(set(ref_dir), set(new_dir)) + def run_precision_test(self, cls, init_kwargs, input_data): + # Keras 2 has some errors at non-float32 precision. + if not config.keras_3(): + return + # Never test mixed precision on torch CPU. Torch lacks support.
+ if config.backend() == "torch": + import torch + + if not torch.cuda.is_available(): + return + + for policy in ["mixed_float16", "mixed_bfloat16", "bfloat16"]: + policy = keras.mixed_precision.Policy(policy) + layer = cls(**{**init_kwargs, "dtype": policy}) + if isinstance(layer, keras.Model): + output_data = layer(input_data) + elif isinstance(input_data, dict): + output_data = layer(**input_data) + else: + output_data = layer(input_data) + for tensor in tree.flatten(output_data): + if is_float_dtype(tensor.dtype): + self.assertDTypeEqual(tensor, policy.compute_dtype) + for weight in layer.weights: + if is_float_dtype(weight.dtype): + self.assertDTypeEqual(weight, policy.variable_dtype) + for sublayer in layer._flatten_layers(include_self=False): + if isinstance( + sublayer, (keras.layers.Softmax, keras.layers.InputLayer) + ): + continue + self.assertEqual(policy.compute_dtype, sublayer.compute_dtype) + self.assertEqual(policy.variable_dtype, sublayer.variable_dtype) + def run_model_saving_test( self, cls, @@ -304,6 +364,7 @@ def run_backbone_test( input_data, expected_output_shape, variable_length_data=None, + run_mixed_precision_check=True, ): """Run basic tests for a backbone, including compilation.""" backbone = cls(**init_kwargs) @@ -345,6 +406,8 @@ def run_backbone_test( name = re.sub("([a-z])([A-Z])", r"\1_\2", name).lower() self.assertRegexpMatches(backbone.name, name) + self.run_precision_test(cls, init_kwargs, input_data) + def run_task_test( self, cls, diff --git a/keras_nlp/tests/test_data/gemma_test_vocab.spm b/keras_nlp/tests/test_data/gemma_test_vocab.spm new file mode 100644 index 0000000000..a049c032c2 Binary files /dev/null and b/keras_nlp/tests/test_data/gemma_test_vocab.spm differ diff --git a/keras_nlp/tests/test_data/llama_test_vocab.spm b/keras_nlp/tests/test_data/llama_test_vocab.spm new file mode 100644 index 0000000000..d753476f53 Binary files /dev/null and b/keras_nlp/tests/test_data/llama_test_vocab.spm differ diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 55992a16d7..902af812e9 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -350,7 +350,8 @@ def set_vocabulary_and_merges(self, vocabulary, merges): f"`type(vocabulary)={type(vocabulary)}`." ) if isinstance(merges, str): - self.merges = [bp.rstrip() for bp in open(merges, encoding="utf-8")] + with open(merges, encoding="utf-8") as f: + self.merges = [bp.rstrip() for bp in f] elif isinstance(merges, Iterable): self.merges = list(merges) else: diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py index ae655aceb6..64e169939c 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py @@ -253,6 +253,8 @@ def tokenize(self, inputs): def detokenize(self, inputs): self._check_vocabulary() inputs, unbatched, _ = convert_to_ragged_batch(inputs) + # tf-text sentencepiece does not handle int64. 
+ inputs = tf.cast(inputs, "int32") outputs = self._sentence_piece.detokenize(inputs) if unbatched: outputs = tf.squeeze(outputs, 0) diff --git a/keras_nlp/utils/pipeline_model.py b/keras_nlp/utils/pipeline_model.py index fa08aaf929..89a2f81822 100644 --- a/keras_nlp/utils/pipeline_model.py +++ b/keras_nlp/utils/pipeline_model.py @@ -142,27 +142,15 @@ def _split(t, start, end): class PipelineModel(keras.Model): """A model which allows automatically applying preprocessing.""" - def __init__(self, *args, include_preprocessing=True, **kwargs): + def __init__(self, *args, **kwargs): # Workaround for https://github.com/keras-team/keras/issues/17270 # Reset any attempt to overwrite this classes base class to this class # can continue to be used for functional and non-functional models. PipelineModel.__bases__ = (keras.Model,) super().__init__(*args, **kwargs) - self.include_preprocessing = include_preprocessing - - def preprocess_features(self, x): - """An overridable function which preprocesses features.""" - return x - - def preprocess_labels(self, y): - """An overridable function which preprocesses labels.""" - return y def preprocess_samples(self, x, y=None, sample_weight=None): """An overridable function which preprocesses entire samples.""" - x = self.preprocess_features(x) - if y is not None: - y = self.preprocess_labels(y) return pack_x_y_sample_weight(x, y, sample_weight) # ======================================================================== @@ -184,10 +172,9 @@ def fit( ) x = _convert_inputs_to_dataset(x, y, sample_weight, batch_size) - if self.include_preprocessing: - x = x.map( - self.preprocess_samples, num_parallel_calls=tf.data.AUTOTUNE - ).prefetch(tf.data.AUTOTUNE) + x = x.map( + self.preprocess_samples, num_parallel_calls=tf.data.AUTOTUNE + ).prefetch(tf.data.AUTOTUNE) if validation_data is not None: if not isinstance(validation_data, tf.data.Dataset): @@ -221,10 +208,9 @@ def evaluate( # needs preprocessing. 
kwargs.pop("_use_cached_eval_dataset", None) x = _convert_inputs_to_dataset(x, y, sample_weight, batch_size) - if self.include_preprocessing: - x = x.map( - self.preprocess_samples, num_parallel_calls=tf.data.AUTOTUNE - ).prefetch(tf.data.AUTOTUNE) + x = x.map( + self.preprocess_samples, num_parallel_calls=tf.data.AUTOTUNE + ).prefetch(tf.data.AUTOTUNE) return super().evaluate( x=x, y=None, @@ -239,11 +225,9 @@ def predict( **kwargs, ): x = _convert_inputs_to_dataset(x, None, None, batch_size) - if self.include_preprocessing: - x = x.map( - self.preprocess_samples, num_parallel_calls=tf.data.AUTOTUNE - ).prefetch(tf.data.AUTOTUNE) - + x = x.map( + self.preprocess_samples, num_parallel_calls=tf.data.AUTOTUNE + ).prefetch(tf.data.AUTOTUNE) return super().predict( x=x, batch_size=None, @@ -257,14 +241,13 @@ def train_on_batch( sample_weight=None, **kwargs, ): - if self.include_preprocessing: - data = self.preprocess_samples(x, y, sample_weight) - x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) - x = ops.convert_to_tensor(x) - if y is not None: - y = ops.convert_to_tensor(y) - if sample_weight is not None: - sample_weight = ops.convert_to_tensor(sample_weight) + data = self.preprocess_samples(x, y, sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) + x = ops.convert_to_tensor(x) + if y is not None: + y = ops.convert_to_tensor(y) + if sample_weight is not None: + sample_weight = ops.convert_to_tensor(sample_weight) return super().train_on_batch( x=x, y=y, @@ -279,14 +262,13 @@ def test_on_batch( sample_weight=None, **kwargs, ): - if self.include_preprocessing: - data = self.preprocess_samples(x, y, sample_weight) - x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) - x = ops.convert_to_tensor(x) - if y is not None: - y = ops.convert_to_tensor(y) - if sample_weight is not None: - sample_weight = ops.convert_to_tensor(sample_weight) + data = self.preprocess_samples(x, y, sample_weight) + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) + x = ops.convert_to_tensor(x) + if y is not None: + y = ops.convert_to_tensor(y) + if sample_weight is not None: + sample_weight = ops.convert_to_tensor(sample_weight) return super().test_on_batch( x=x, y=y, @@ -299,10 +281,9 @@ def predict_on_batch( x, **kwargs, ): - if self.include_preprocessing: - data = self.preprocess_samples(x) - x, _, _ = keras.utils.unpack_x_y_sample_weight(data) - x = ops.convert_to_tensor(x) + data = self.preprocess_samples(x) + x, _, _ = keras.utils.unpack_x_y_sample_weight(data) + x = ops.convert_to_tensor(x) return super().predict_on_batch( x=x, **kwargs, diff --git a/keras_nlp/utils/pipeline_model_test.py b/keras_nlp/utils/pipeline_model_test.py index 4c7c7f1964..ae71f9f570 100644 --- a/keras_nlp/utils/pipeline_model_test.py +++ b/keras_nlp/utils/pipeline_model_test.py @@ -36,8 +36,9 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.dense = keras.layers.Dense(1) - def preprocess_features(self, x): - return tf.strings.to_number(x) + def preprocess_samples(self, x, y=None, sample_weight=None): + x = tf.strings.to_number(x) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) def call(self, inputs): return self.dense(inputs) @@ -48,8 +49,10 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.dense = keras.layers.Dense(1) - def preprocess_labels(self, y): - return tf.strings.to_number(y) + def preprocess_samples(self, x, y=None, sample_weight=None): + if y is not None: + y = tf.strings.to_number(y) + return 
keras.utils.pack_x_y_sample_weight(x, y, sample_weight) def call(self, inputs): return self.dense(inputs) @@ -63,8 +66,7 @@ def __init__(self, **kwargs): self.dense = keras.layers.Dense(1) def preprocess_samples(self, x, y=None, sample_weight=None): - x = tf.strings.to_number(x) - y = x + y = x = tf.strings.to_number(x) return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) def call(self, inputs): @@ -77,8 +79,9 @@ def __init__(self, **kwargs): outputs = keras.layers.Dense(1)(inputs) super().__init__(inputs, outputs, **kwargs) - def preprocess_features(self, x): - return tf.strings.to_number(x) + def preprocess_samples(self, x, y=None, sample_weight=None): + x = tf.strings.to_number(x) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) def get_config(self): return {} @@ -167,19 +170,6 @@ def test_fit_with_preprocessing(self): model.fit(x=x, y=y, batch_size=8) model.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_fit_no_preprocessing(self): - x = np.random.uniform(size=(100, 5)) - y = np.random.uniform(size=(100, 1)) - sw = np.random.uniform(size=(100, 1)) - model = FeaturePipeline(include_preprocessing=False) - model.compile(loss="mse") - # With sample weight. - model.fit(x=x, y=y, sample_weight=sw, batch_size=8) - model.fit(tf.data.Dataset.from_tensor_slices((x, y, sw)).batch(8)) - # Without sample weight. - model.fit(x=x, y=y, batch_size=8) - model.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_evaluate_with_preprocessing(self): x = tf.strings.as_string(np.random.uniform(size=(100, 5))) y = np.random.uniform(size=(100, 1)) @@ -193,19 +183,6 @@ def test_evaluate_with_preprocessing(self): model.evaluate(x=x, y=y, batch_size=8) model.evaluate(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_evaluate_no_preprocessing(self): - x = np.random.uniform(size=(100, 5)) - y = np.random.uniform(size=(100, 1)) - sw = np.random.uniform(size=(100, 1)) - model = FeaturePipeline(include_preprocessing=False) - model.compile(loss="mse") - # With sample weight. - model.evaluate(x=x, y=y, sample_weight=sw, batch_size=8) - model.evaluate(tf.data.Dataset.from_tensor_slices((x, y, sw)).batch(8)) - # Without sample weight. - model.evaluate(x=x, y=y, batch_size=8) - model.evaluate(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_predict_with_preprocessing(self): x = tf.strings.as_string(np.random.uniform(size=(100, 5))) model = FeaturePipeline() @@ -213,13 +190,6 @@ def test_predict_with_preprocessing(self): model.predict(x=x, batch_size=8) model.predict(tf.data.Dataset.from_tensor_slices(x).batch(8)) - def test_predict_no_preprocessing(self): - x = np.random.uniform(size=(100, 5)) - model = FeaturePipeline(include_preprocessing=False) - model.compile(loss="mse") - model.predict(x=x, batch_size=8) - model.predict(tf.data.Dataset.from_tensor_slices(x).batch(8)) - def test_on_batch(self): x = tf.strings.as_string(np.random.uniform(size=(8, 5))) y = np.random.uniform(size=(8, 1)) @@ -234,19 +204,6 @@ def test_on_batch(self): model.test_on_batch(x=x, y=y) model.predict_on_batch(x=x) - def test_on_batch_no_preprocessing(self): - x = np.random.uniform(size=(8, 5)) - y = np.random.uniform(size=(8, 1)) - sw = np.random.uniform(size=(8, 1)) - model = FeaturePipeline(include_preprocessing=False) - model.compile(loss="mse") - # With sample weight. - model.train_on_batch(x=x, y=y, sample_weight=sw) - model.test_on_batch(x=x, y=y, sample_weight=sw) - # Without sample weight. 
- model.train_on_batch(x=x, y=y) - model.test_on_batch(x=x, y=y) - def test_saved_model(self): model = FeaturePipeline() x = tf.strings.as_string(np.random.uniform(size=(8, 5))) @@ -278,19 +235,6 @@ def test_fit_with_preprocessing(self): model.fit(x=x, y=y, batch_size=8) model.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_fit_no_preprocessing(self): - x = np.random.uniform(size=(100, 5)) - y = np.random.uniform(size=(100, 1)) - sw = np.random.uniform(size=(100, 1)) - model = LabelPipeline(include_preprocessing=False) - model.compile(loss="mse") - # With sample weight. - model.fit(x=x, y=y, sample_weight=sw, batch_size=8) - model.fit(tf.data.Dataset.from_tensor_slices((x, y, sw)).batch(8)) - # Without sample weight. - model.fit(x=x, y=y, batch_size=8) - model.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_evaluate_with_preprocessing(self): x = np.random.uniform(size=(100, 5)) y = tf.strings.as_string(np.random.uniform(size=(100, 1))) @@ -304,19 +248,6 @@ def test_evaluate_with_preprocessing(self): model.evaluate(x=x, y=y, batch_size=8) model.evaluate(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_evaluate_no_preprocessing(self): - x = np.random.uniform(size=(100, 5)) - y = np.random.uniform(size=(100, 1)) - sw = np.random.uniform(size=(100, 1)) - model = LabelPipeline(include_preprocessing=False) - model.compile(loss="mse") - # With sample weight. - model.evaluate(x=x, y=y, sample_weight=sw, batch_size=8) - model.evaluate(tf.data.Dataset.from_tensor_slices((x, y, sw)).batch(8)) - # Without sample weight. - model.evaluate(x=x, y=y, batch_size=8) - model.evaluate(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_predict_with_preprocessing(self): x = np.random.uniform(size=(100, 5)) model = LabelPipeline() @@ -338,20 +269,6 @@ def test_on_batch(self): model.test_on_batch(x=x, y=y) model.predict_on_batch(x=x) - def test_on_batch_no_preprocessing(self): - x = np.random.uniform(size=(8, 5)) - y = np.random.uniform(size=(8, 1)) - sw = np.random.uniform(size=(8, 1)) - model = LabelPipeline(include_preprocessing=False) - model.compile(loss="mse") - # With sample weight. - model.train_on_batch(x=x, y=y, sample_weight=sw) - model.test_on_batch(x=x, y=y, sample_weight=sw) - # Without sample weight. 
- model.train_on_batch(x=x, y=y) - model.test_on_batch(x=x, y=y) - model.predict_on_batch(x=x) - def test_saved_model(self): model = LabelPipeline() x = np.random.uniform(size=(8, 5)) @@ -377,14 +294,6 @@ def test_fit_with_preprocessing(self): model.fit(x=data, batch_size=8) model.fit(tf.data.Dataset.from_tensor_slices(data).batch(8)) - def test_fit_no_preprocessing(self): - x = np.random.uniform(size=(100, 1)) - y = np.random.uniform(size=(100, 1)) - model = DataPipeline(include_preprocessing=False) - model.compile(loss="mse") - model.fit(x=x, y=y, batch_size=8) - model.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_evaluate_with_preprocessing(self): data = tf.strings.as_string(np.random.uniform(size=(100, 1))) model = DataPipeline() @@ -392,14 +301,6 @@ def test_evaluate_with_preprocessing(self): model.evaluate(x=data, batch_size=8) model.evaluate(tf.data.Dataset.from_tensor_slices(data).batch(8)) - def test_evaluate_no_preprocessing(self): - x = np.random.uniform(size=(100, 1)) - y = np.random.uniform(size=(100, 1)) - model = DataPipeline(include_preprocessing=False) - model.compile(loss="mse") - model.evaluate(x=x, y=y, batch_size=8) - model.evaluate(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_predict_with_preprocessing(self): x = tf.strings.as_string(np.random.uniform(size=(100, 1))) model = DataPipeline() @@ -407,13 +308,6 @@ def test_predict_with_preprocessing(self): model.predict(x=x, batch_size=8) model.predict(tf.data.Dataset.from_tensor_slices(x).batch(8)) - def test_predict_no_preprocessing(self): - x = np.random.uniform(size=(100, 1)) - model = DataPipeline(include_preprocessing=False) - model.compile(loss="mse") - model.predict(x=x, batch_size=8) - model.predict(tf.data.Dataset.from_tensor_slices(x).batch(8)) - def test_on_batch(self): data = tf.strings.as_string(np.random.uniform(size=(8, 1))) model = DataPipeline() @@ -426,20 +320,6 @@ def test_on_batch(self): model.test_on_batch(x=data) model.predict_on_batch(x=data) - def test_on_batch_no_preprocessing(self): - x = np.random.uniform(size=(8, 1)) - y = np.random.uniform(size=(8, 1)) - sw = np.random.uniform(size=(8, 1)) - model = DataPipeline(include_preprocessing=False) - model.compile(loss="mse") - # With sample weight. - model.train_on_batch(x=x, y=y, sample_weight=sw) - model.test_on_batch(x=x, y=y, sample_weight=sw) - # Without sample weight. - model.train_on_batch(x=x, y=y) - model.test_on_batch(x=x, y=y) - model.predict_on_batch(x=x) - def test_saved_model(self): model = DataPipeline() data = tf.strings.as_string(np.random.uniform(size=(8, 1))) @@ -472,19 +352,6 @@ def test_fit(self): model.fit(x=x, y=y, batch_size=8) model.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_fit_no_preprocessing(self): - x = np.random.uniform(size=(100, 5)) - y = np.random.uniform(size=(100, 1)) - sw = np.random.uniform(size=(100, 1)) - model = FunctionalPipeline(include_preprocessing=False) - model.compile(loss="mse") - # With sample weight. - model.fit(x=x, y=y, sample_weight=sw, batch_size=8) - model.fit(tf.data.Dataset.from_tensor_slices((x, y, sw)).batch(8)) - # Without sample weight. 
- model.fit(x=x, y=y, batch_size=8) - model.fit(tf.data.Dataset.from_tensor_slices((x, y)).batch(8)) - def test_saved_model(self): model = FunctionalPipeline() x = tf.strings.as_string(np.random.uniform(size=(8, 5))) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index 6bb2748fd9..01c11a3db1 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -16,6 +16,7 @@ import json import os +from keras_nlp.backend import config as backend_config from keras_nlp.backend import keras try: @@ -180,6 +181,13 @@ def load_from_preset( # Optionally load weights. load_weights = load_weights and config["weights"] if load_weights: + # For jax, delete all previous allocated memory to avoid temporarily + # duplicating variable allocations. torch and tensorflow have stateful + # variable types and do not need this fix. + if backend_config.backend() == "jax": + for weight in layer.weights: + if getattr(weight, "_value", None) is not None: + weight._value.delete() weights_path = get_file(preset, config["weights"]) layer.load_weights(weights_path) diff --git a/keras_nlp/version_utils.py b/keras_nlp/version_utils.py index 15fede3a08..4d6a8186d4 100644 --- a/keras_nlp/version_utils.py +++ b/keras_nlp/version_utils.py @@ -15,7 +15,7 @@ from keras_nlp.api_export import keras_nlp_export # Unique source of truth for the version number. -__version__ = "0.7.0" +__version__ = "0.8.0" @keras_nlp_export("keras_nlp.version") diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index c09b306264..903a603352 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tf-nightly-cpu==2.16.0.dev20231221 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20231221 # Pin a working nightly until rc0. +tf-nightly-cpu==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu @@ -8,7 +8,8 @@ torch>=2.1.0 torchvision>=0.16.0 # Jax with cuda support. +# TODO: 0.4.24 has an updated Cuda version breaks Jax CI. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -jax[cuda12_pip] +jax[cuda12_pip]==0.4.23 -r requirements-common.txt diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index 21a8ed2463..be95915996 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow with cuda support. -tf-nightly[and-cuda]==2.16.0.dev20231221 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20231221 # Pin a working nightly until rc0. +tf-nightly[and-cuda]==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index c71c51e478..7ea2981478 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tf-nightly-cpu==2.16.0.dev20231221 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20231221 # Pin a working nightly until rc0. +tf-nightly-cpu==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. # Torch with cuda support. 
--extra-index-url https://download.pytorch.org/whl/cu121 diff --git a/requirements.txt b/requirements.txt index fa1dc91943..b226229d15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Tensorflow. -tf-nightly-cpu==2.16.0.dev20231221 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20231221 # Pin a working nightly until rc0. +tf-nightly-cpu==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. # Torch. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/tools/checkpoint_conversion/convert_bloom_checkpoints.py b/tools/checkpoint_conversion/convert_bloom_checkpoints.py new file mode 100644 index 0000000000..38acd099cf --- /dev/null +++ b/tools/checkpoint_conversion/convert_bloom_checkpoints.py @@ -0,0 +1,248 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import huggingface_hub # noqa: E402 +import numpy as np # noqa: E402 +import torch # noqa: E402 +import transformers # noqa: E402 +from absl import app # noqa: E402 +from absl import flags # noqa: E402 + +import keras_nlp # noqa: E402 +from keras_nlp.models import BloomBackbone # noqa: E402 +from keras_nlp.models import BloomTokenizer # noqa: E402 + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "bloom_560m_multi": "bigscience/bloom-560m", + "bloom_1.1b_multi": "bigscience/bloom-1b1", + "bloom_1.7b_multi": "bigscience/bloom-1b7", + "bloom_3b_multi": "bigscience/bloom-3b", + "bloom_7b_multi": "bigscience/bloom-7b1", + "bloom_176b_multi": "bigscience/bloom", +} + +EXTRACT_DIR = "./model" + + +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' +) +flags.mark_flag_as_required("preset") + + +def download_hf_model(hf_model_name): + hf_model_dir = huggingface_hub.snapshot_download( + repo_id=hf_model_name, + allow_patterns=["*.json", "*.bin"], + ignore_patterns=["onnx/*"], + local_dir=EXTRACT_DIR, + ) + + return hf_model_dir + + +def convert_model(hf_model): + # get huggingface model configuration. 
+ hf_config = hf_model.config.to_dict() + + kwargs = {} + kwargs["vocabulary_size"] = hf_config["vocab_size"] + kwargs["num_layers"] = hf_config["n_layer"] + kwargs["num_heads"] = hf_config["n_head"] + kwargs["hidden_dim"] = hf_config["hidden_size"] + kwargs["intermediate_dim"] = hf_config["hidden_size"] * 4 + kwargs["dropout"] = hf_config["hidden_dropout"] + kwargs["layer_norm_epsilon"] = hf_config["layer_norm_epsilon"] + + return BloomBackbone(**kwargs) + + +def convert_tokenizer(hf_model_dir): + tokenizer_file_path = os.path.join(hf_model_dir, "tokenizer.json") + with open(tokenizer_file_path) as tokenizer_file: + hf_tokenizer = json.load(tokenizer_file) + + vocab = hf_tokenizer["model"]["vocab"] + merges = hf_tokenizer["model"]["merges"] + + return BloomTokenizer(vocabulary=vocab, merges=merges) + + +def convert_weights(keras_model, hf_model): + hidden_dim = keras_model.hidden_dim + num_heads = keras_model.num_heads + head_dim = hidden_dim // num_heads + num_layers = keras_model.num_layers + + # get huggingface model weights. + hf_wts = hf_model.state_dict() + + # assign huggingface weights to the keras model. + # Embedding layer. + keras_model.get_layer("token_embedding").embeddings.assign( + hf_wts["word_embeddings.weight"] + ) + # LayerNorm. + keras_model.get_layer("token_embedding_layernorm").gamma.assign( + hf_wts["word_embeddings_layernorm.weight"] + ) + keras_model.get_layer("token_embedding_layernorm").beta.assign( + hf_wts["word_embeddings_layernorm.bias"] + ) + + keras_model.get_layer("final_layernorm").gamma.assign(hf_wts["ln_f.weight"]) + keras_model.get_layer("final_layernorm").beta.assign(hf_wts["ln_f.bias"]) + + # Decoder layers. + for i in range(num_layers): + decoder_layer = keras_model.get_layer(f"transformer_layer_{i}") + # LayerNorm. + decoder_layer._pre_attention_layernorm.gamma.assign( + hf_wts[f"h.{i}.input_layernorm.weight"] + ) + decoder_layer._pre_attention_layernorm.beta.assign( + hf_wts[f"h.{i}.input_layernorm.bias"] + ) + decoder_layer._post_attention_layernorm.gamma.assign( + hf_wts[f"h.{i}.post_attention_layernorm.weight"] + ) + decoder_layer._post_attention_layernorm.beta.assign( + hf_wts[f"h.{i}.post_attention_layernorm.bias"] + ) + + # Attention layer. + attention_layer = decoder_layer._self_attention_layer + + fused_qkv_kernel = hf_wts[ + f"h.{i}.self_attention.query_key_value.weight" + ].T + fused_qkv_kernel = fused_qkv_kernel.view( + hidden_dim, num_heads, 3, head_dim + ) + query_kernel = fused_qkv_kernel[..., 0, :] + key_kernel = fused_qkv_kernel[..., 1, :] + value_kernel = fused_qkv_kernel[..., 2, :] + + fused_qkv_bias = hf_wts[f"h.{i}.self_attention.query_key_value.bias"] + fused_qkv_bias = fused_qkv_bias.view(num_heads, 3, head_dim) + query_bias = fused_qkv_bias[:, 0, :] + key_bias = fused_qkv_bias[:, 1, :] + value_bias = fused_qkv_bias[:, 2, :] + + attention_layer._query_dense.kernel.assign(query_kernel) + attention_layer._query_dense.bias.assign(query_bias) + attention_layer._key_dense.kernel.assign(key_kernel) + attention_layer._key_dense.bias.assign(key_bias) + attention_layer._value_dense.kernel.assign(value_kernel) + attention_layer._value_dense.bias.assign(value_bias) + + attention_layer._output_dense.kernel.assign( + hf_wts[f"h.{i}.self_attention.dense.weight"].T + ) + attention_layer._output_dense.bias.assign( + hf_wts[f"h.{i}.self_attention.dense.bias"] + ) + + # mlp.
+ decoder_layer._mlp_intermediate_dense.kernel.assign( + hf_wts[f"h.{i}.mlp.dense_h_to_4h.weight"].T + ) + decoder_layer._mlp_intermediate_dense.bias.assign( + hf_wts[f"h.{i}.mlp.dense_h_to_4h.bias"] + ) + decoder_layer._mlp_output_dense.kernel.assign( + hf_wts[f"h.{i}.mlp.dense_4h_to_h.weight"].T + ) + decoder_layer._mlp_output_dense.bias.assign( + hf_wts[f"h.{i}.mlp.dense_4h_to_h.bias"] + ) + + +def validate_output( + hf_model, + keras_model, + hf_tokenizer, + keras_tokenizer, +): + input_str = ["the quick brown fox ran, galloped and jumped."] + + # KerasNLP + token_ids = torch.tensor(keras_tokenizer(input_str)) + padding_mask = token_ids != 3 + keras_model_input = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + keras_model_outputs = keras_model.predict(keras_model_input) + + hf_model_input = hf_tokenizer(input_str, return_tensors="pt") + + hf_model_outputs = hf_model(**hf_model_input).last_hidden_state + hf_model_outputs = hf_model_outputs.detach().numpy() + + # Comparing the outputs. + print("šŸ”¶ KerasNLP output:", keras_model_outputs[0, 0, :10]) + print("šŸ”¶ HF output:", hf_model_outputs[0, 0, :10]) + print("šŸ”¶ Difference:", np.mean(keras_model_outputs - hf_model_outputs)) + + +def main(_): + preset = FLAGS.preset + + assert ( + preset in PRESET_MAP.keys() + ), f'Invalid preset {preset}. Must be one of {",".join(PRESET_MAP.keys())}' + + print(f"āœ… Converting {preset}") + + hf_model_name = PRESET_MAP[preset] + hf_model_dir = download_hf_model(hf_model_name) + print("āœ… Huggingface model downloaded from hub") + + hf_model = transformers.BloomModel.from_pretrained(hf_model_dir) + hf_tokenizer = transformers.BloomTokenizerFast.from_pretrained(hf_model_dir) + print("āœ… Huggingface model loaded") + + keras_model = convert_model(hf_model) + keras_tokenizer = convert_tokenizer(hf_model_dir) + print("āœ… Keras model loaded") + + convert_weights(keras_model, hf_model) + print("āœ… Weights converted") + + validate_output( + hf_model, + keras_model, + hf_tokenizer, + keras_tokenizer, + ) + print("āœ… Numerics validated") + + keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset) + keras_nlp.src.utils.preset_utils.save_to_preset( + keras_tokenizer, preset, config_filename="tokenizer.json" + ) + print("āœ… Preset saved") + + +if __name__ == "__main__": + app.run(main) diff --git a/tools/checkpoint_conversion/convert_mistral_checkpoints.py b/tools/checkpoint_conversion/convert_mistral_checkpoints.py index 3bc443d910..8e10089efd 100644 --- a/tools/checkpoint_conversion/convert_mistral_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mistral_checkpoints.py @@ -11,433 +11,342 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
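Before moving on to the Mistral converter, one step of the BLOOM conversion above that is easy to misread is the fused QKV split: the HF checkpoint stores a single `query_key_value` weight, which the script transposes, views as `(hidden_dim, num_heads, 3, head_dim)`, and slices along the third axis into per-head query, key, and value kernels. A small NumPy sketch of that reshape, with toy, hypothetical sizes:

```python
import numpy as np

# Toy, hypothetical sizes; the real BLOOM checkpoints are much larger.
hidden_dim, num_heads = 8, 2
head_dim = hidden_dim // num_heads

# Stand-in for the transposed fused weight, which starts out with shape
# (hidden_dim, num_heads * 3 * head_dim).
fused_qkv = np.arange(
    hidden_dim * num_heads * 3 * head_dim, dtype="float32"
).reshape(hidden_dim, num_heads * 3 * head_dim)

# Mirror of the `.view(hidden_dim, num_heads, 3, head_dim)` call above.
fused_qkv = fused_qkv.reshape(hidden_dim, num_heads, 3, head_dim)

# Slicing the third axis yields per-head query/key/value kernels, each with
# shape (hidden_dim, num_heads, head_dim).
query_kernel = fused_qkv[..., 0, :]
key_kernel = fused_qkv[..., 1, :]
value_kernel = fused_qkv[..., 2, :]
print(query_kernel.shape)  # (8, 2, 4)
```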
+import datetime +import gc import json +import os import pathlib -from dataclasses import dataclass -from pathlib import Path -from typing import Optional -from typing import Tuple - -import torch -from torch import nn - +import traceback + +import keras +import numpy as np +import requests +from absl import app +from absl import flags +from keras import ops +from transformers import AutoTokenizer +from transformers import MistralForCausalLM + +import keras_nlp from keras_nlp.models import MistralBackbone +from keras_nlp.models import MistralCausalLMPreprocessor +from keras_nlp.models import MistralTokenizer -MODEL_PATH = pathlib.Path("mistral-7B-v0.1") - -# Torch model taken from: -# https://github.com/mistralai/mistral-src/blob/147c4e68279b90eb61b19bdea44e16f5539d5a5d/one_file_ref.py - - -@dataclass -class ModelArgs: - dim: int - n_layers: int - head_dim: int - hidden_dim: int - n_heads: int - n_kv_heads: int - sliding_window: int - norm_eps: float - vocab_size: int - - max_batch_size: int = 0 - - -def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int): - keys = torch.repeat_interleave(keys, repeats=repeats, dim=2) - values = torch.repeat_interleave(values, repeats=repeats, dim=2) - return keys, values - - -def _reshape_for_broadcast( - freqs_cis: torch.Tensor, x: torch.Tensor -) -> torch.Tensor: - """ - freqs_cis: complex - (seq_len, head_dim / 2) - x: complex - (bsz, seq_len, head_dim / 2) - """ - ndim = x.ndim - assert 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( - freqs_cis.shape, - (x.shape[1], x.shape[-1]), - ) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.n_heads: int = args.n_heads - self.n_kv_heads: int = args.n_kv_heads - - self.repeats = self.n_heads // self.n_kv_heads - self.sliding_window = self.args.sliding_window - - self.scale = self.args.head_dim**-0.5 - - self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) - self.wk = nn.Linear( - args.dim, args.n_kv_heads * args.head_dim, bias=False - ) - self.wv = nn.Linear( - args.dim, args.n_kv_heads * args.head_dim, bias=False - ) - self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) - self.cache_k = torch.empty( - ( - args.max_batch_size, - args.sliding_window, - self.n_kv_heads, - self.args.head_dim, - ), - dtype=torch.float16, - ) - self.cache_v = torch.empty( - ( - args.max_batch_size, - args.sliding_window, - self.n_kv_heads, - self.args.head_dim, - ), - dtype=torch.float16, - ) - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - positions: torch.Tensor, - mask: Optional[torch.Tensor], - ) -> torch.Tensor: - bsz, seqlen, _ = x.shape - - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim) - xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim) - xv = xv.view(bsz, seqlen, 
self.n_kv_heads, self.args.head_dim) - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - - # The cache is a rotating buffer - scatter_pos = (positions[-self.sliding_window :] % self.sliding_window)[ - None, :, None, None - ] - scatter_pos = scatter_pos.repeat( - bsz, 1, self.n_kv_heads, self.args.head_dim - ) - self.cache_k[:bsz].scatter_( - dim=1, - index=scatter_pos, - src=xk[:, -self.sliding_window :].to(self.cache_k.dtype), - ) - self.cache_v[:bsz].scatter_( - dim=1, - index=scatter_pos, - src=xv[:, -self.sliding_window :].to(self.cache_v.dtype), - ) - - if positions.shape[0] > 1: - # prefill - key, value = repeat_kv(xk, xv, self.repeats) - else: - cur_pos = positions[-1].item() + 1 - key, value = repeat_kv( - self.cache_k[:bsz, :cur_pos, ...].to(xk.dtype), - self.cache_v[:bsz, :cur_pos, ...].to(xv.dtype), - self.repeats, - ) - - query = xq.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - # scores : [bsz, n_heads, seqlen | 1, seqlen] - scores = torch.matmul(query, key.transpose(2, 3)) * self.scale - - if mask is not None: - scores += mask[None, None, ...] - - scores = scores.float() - scores = nn.functional.softmax(scores, dim=-1).type_as(query) - output = torch.matmul( - scores, value - ) # (bs, n_local_heads, slen, head_dim) - output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) - return self.wo(output) - - -class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) - self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) - self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) - - def forward(self, x) -> torch.Tensor: - return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.attention = Attention(args) - self.feed_forward = FeedForward(args=args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.args = args - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - positions: torch.Tensor, - mask: Optional[torch.Tensor], - ) -> torch.Tensor: - r = self.attention.forward( - self.attention_norm(x), freqs_cis, positions, mask - ) - h = x + r - r = self.feed_forward.forward(self.ffn_norm(h)) - out = h + r - return out - - -def precompute_freqs_cis( - dim: int, end: int, theta: float = 10000.0 -) -> torch.Tensor: - freqs = 1.0 / ( - theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) - ) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - return torch.polar(torch.ones_like(freqs), freqs) # complex64 +PRESET_MAP = { + "mistral_7b_en": "mistralai/Mistral-7B-v0.1", + "mistral_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.1", +} +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' +) -class TorchTransformer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - 
self.vocab_size = args.vocab_size - self.n_layers = args.n_layers - assert self.vocab_size > 0 - self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) +def convert_checkpoints(keras_nlp_model, hf_model): + config = hf_model.config - self.layers = torch.nn.ModuleList( - [TransformerBlock(args=args) for _ in range(args.n_layers)] - ) - - self.norm = RMSNorm(args.dim, eps=args.norm_eps) - - self.output = nn.Linear(args.dim, args.vocab_size, bias=False) - - self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - ): - h = self.tok_embeddings(input_ids) - freqs_cis = self.freqs_cis[positions] - - mask: Optional[torch.Tensor] = None - if input_ids.shape[1] > 1: - seqlen = input_ids.shape[1] - tensor = torch.full( - (seqlen, seqlen), - dtype=h.dtype, - fill_value=1, - device=h.device, - ) - mask = torch.tril(tensor, diagonal=0).to(h.dtype) - # make the mask banded to account for sliding window - mask = torch.triu(mask, diagonal=-self.args.sliding_window) - mask = torch.log(mask) - - for layer in self.layers: - h = layer(h, freqs_cis, positions, mask) - - return self.output(self.norm(h)).float() - - @staticmethod - def from_folder( - folder: Path, max_batch_size: int = 1, device="cpu", dtype=torch.float16 - ): - with open(folder / "params.json", "r") as f: - model_args = ModelArgs(**json.loads(f.read())) - model_args.max_batch_size = max_batch_size - model = TorchTransformer(model_args).to(device=device, dtype=dtype) - loaded = torch.load(folder / "consolidated.00.pth") - model.load_state_dict(loaded) - return model - - -def port_weights( - model_k3: MistralBackbone, model_torch: TorchTransformer, params: ModelArgs -): - model_k3.get_layer("token_embedding").embeddings.assign( - model_torch.tok_embeddings.weight.detach().cpu().numpy() + keras_nlp_model.token_embedding.embeddings.assign( + hf_model.model.embed_tokens.weight.detach().cpu().numpy() ) - for i in range(model_k3.num_layers): - model_k3.get_layer( - f"transformer_layer_{i}" - )._self_attention_layer._key_dense.set_weights( + for i in range(keras_nlp_model.num_layers): + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layer._key_dense.set_weights( [ - model_torch.layers[i] - .attention.wk.weight.T.reshape( - params.dim, params.n_kv_heads, params.head_dim + hf_model.model.layers[i] + .self_attn.k_proj.weight.T.reshape( + config.hidden_size, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, ) .detach() .cpu() .numpy() ] ) - model_k3.get_layer( - f"transformer_layer_{i}" - )._self_attention_layer._query_dense.set_weights( + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layer._query_dense.set_weights( [ - model_torch.layers[i] - .attention.wq.weight.T.reshape( - params.dim, params.n_heads, params.head_dim + hf_model.model.layers[i] + .self_attn.q_proj.weight.T.reshape( + config.hidden_size, + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, ) .detach() .cpu() .numpy() ] ) - model_k3.get_layer( - f"transformer_layer_{i}" - )._self_attention_layer._value_dense.set_weights( + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layer._value_dense.set_weights( [ - model_torch.layers[i] - .attention.wv.weight.T.reshape( - params.dim, params.n_kv_heads, params.head_dim + hf_model.model.layers[i] + .self_attn.v_proj.weight.T.reshape( + config.hidden_size, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, ) .detach() .cpu() 
.numpy() ] ) - model_k3.get_layer( - f"transformer_layer_{i}" - )._self_attention_layer._output_dense.set_weights( + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layer._output_dense.set_weights( [ - model_torch.layers[i] - .attention.wo.weight.T.reshape( - params.n_heads, params.head_dim, params.dim + hf_model.model.layers[i] + .self_attn.o_proj.weight.T.reshape( + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + config.hidden_size, ) .detach() .cpu() .numpy() ] ) - model_k3.get_layer( - f"transformer_layer_{i}" - )._self_attention_layernorm.set_weights( - [model_torch.layers[i].attention_norm.weight.detach().cpu().numpy()] + keras_nlp_model.transformer_layers[ + i + ]._self_attention_layernorm.set_weights( + [ + hf_model.model.layers[i] + .input_layernorm.weight.detach() + .cpu() + .numpy() + ] ) - model_k3.get_layer( - f"transformer_layer_{i}" - )._feedforward_intermediate_dense.set_weights( + keras_nlp_model.transformer_layers[ + i + ]._feedforward_intermediate_dense.set_weights( [ - model_torch.layers[i] - .feed_forward.w3.weight.T.detach() + hf_model.model.layers[i] + .mlp.up_proj.weight.T.detach() .cpu() .numpy() ] ) - model_k3.get_layer( - f"transformer_layer_{i}" - )._feedforward_output_dense.set_weights( + keras_nlp_model.transformer_layers[ + i + ]._feedforward_output_dense.set_weights( [ - model_torch.layers[i] - .feed_forward.w2.weight.T.detach() + hf_model.model.layers[i] + .mlp.down_proj.weight.T.detach() .cpu() .numpy() ] ) - model_k3.get_layer( - f"transformer_layer_{i}" - )._feedforward_gate_dense.set_weights( + keras_nlp_model.transformer_layers[ + i + ]._feedforward_gate_dense.set_weights( [ - model_torch.layers[i] - .feed_forward.w1.weight.T.detach() + hf_model.model.layers[i] + .mlp.gate_proj.weight.T.detach() .cpu() .numpy() ] ) - model_k3.get_layer( - f"transformer_layer_{i}" - )._feedforward_layernorm.set_weights( - [model_torch.layers[i].ffn_norm.weight.detach().cpu().numpy()] + keras_nlp_model.transformer_layers[ + i + ]._feedforward_layernorm.set_weights( + [ + hf_model.model.layers[i] + .post_attention_layernorm.weight.detach() + .cpu() + .numpy() + ] ) - model_k3.get_layer("sequence_output_layernorm").set_weights( - [model_torch.norm.weight.detach().cpu().numpy()] + keras_nlp_model.layer_norm.set_weights( + [hf_model.model.norm.weight.detach().cpu().numpy()] ) - model_k3.get_layer("token_embedding").reverse_embeddings.assign( - model_torch.output.weight.T.detach().cpu().numpy() + keras_nlp_model.token_embedding.reverse_embeddings.assign( + hf_model.lm_head.weight.T.detach().cpu().numpy() ) -if __name__ == "__main__": - with open(MODEL_PATH / "params.json", "r") as params_file: - params = ModelArgs(**json.load(params_file)) +def test_model( + keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_model_tokenizer +): + # First, test that the number of parameters match + keras_nlp_params = keras_nlp_model.count_params() + hf_params = hf_model.num_parameters() + assert keras_nlp_params == hf_params + + # Test the outputs of both the models + hf_outputs = hf_model( + **hf_model_tokenizer(["What is Keras?"], return_tensors="pt") + ) + hf_output_logits = hf_outputs.logits.detach().cpu().numpy() - model_torch = TorchTransformer.from_folder( - MODEL_PATH, device="cpu", dtype=torch.float16 + keras_nlp_preprocessor = MistralCausalLMPreprocessor(keras_nlp_tokenizer) + keras_nlp_output = keras_nlp_model( + keras_nlp_preprocessor(["What is Keras?"], sequence_length=6)[0] ) - print("Torch model loaded") - model_k3 = MistralBackbone( 
- vocabulary_size=32000, - hidden_dim=4096, - num_layers=32, - num_query_heads=32, - num_key_value_heads=8, - intermediate_dim=14336, - sliding_window=4096, - layer_norm_epsilon=1e-6, - dtype="float16", + keras_nlp_logits = keras_nlp_model.token_embedding( + keras_nlp_output, reverse=True ) - print("Keras 3 model loaded.") + keras_nlp_logits = ops.convert_to_numpy(keras_nlp_logits) + + # High tolerence since bfloat16 is used as the default dtype for Mistral + try: + np.testing.assert_allclose( + keras_nlp_logits, hf_output_logits, atol=1e-4 + ) + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + +def test_tokenizer(keras_nlp_tokenizer, hf_tokenizer): + hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") + hf_output = hf_output["input_ids"].detach().cpu().numpy() + keras_nlp_preprocessor = MistralCausalLMPreprocessor(keras_nlp_tokenizer) + keras_nlp_output = keras_nlp_preprocessor( + ["What is Keras?"], sequence_length=6 + ) + keras_nlp_output = ops.convert_to_numpy(keras_nlp_output[0]["token_ids"]) + + np.testing.assert_equal(keras_nlp_output, hf_output) - port_weights(model_k3, model_torch, params) - print("Weight transfer done.") - model_k3.save_weights("mistral_7b.weights.h5") - print("Weights saved.") +def main(_): + # === Get the preset name === + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + + # === Create the save directories === + model_dir = pathlib.Path(__file__).parent / f"{preset}" + tokenizer_dir = model_dir / "assets" / "tokenizer" + if not model_dir.exists(): + os.makedirs(model_dir) + if not tokenizer_dir.exists(): + os.makedirs(tokenizer_dir) + + # === Load the Huggingface model === + hf_model = MistralForCausalLM.from_pretrained(hf_preset) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset) + hf_model.eval() + print("\n-> Huggingface model and tokenizer loaded") + + # === Load the KerasNLP model === + keras_nlp_config = dict( + vocabulary_size=hf_model.config.vocab_size, + hidden_dim=hf_model.config.hidden_size, + num_layers=hf_model.config.num_hidden_layers, + num_query_heads=hf_model.config.num_attention_heads, + num_key_value_heads=hf_model.config.num_key_value_heads, + intermediate_dim=hf_model.config.intermediate_size, + sliding_window=hf_model.config.sliding_window, + layer_norm_epsilon=hf_model.config.rms_norm_eps, + rope_max_wavelength=hf_model.config.rope_theta, + dtype="float32", + ) + keras_nlp_model = MistralBackbone(**keras_nlp_config) + + # === Download the tokenizer from Huggingface model card === + spm_path = ( + f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" + ) + response = requests.get(spm_path) + if not response.ok: + raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") + tokenizer_path = tokenizer_dir / "vocabulary.spm" + with open(tokenizer_path, "wb") as tokenizer_file: + tokenizer_file.write(response.content) + keras_nlp_tokenizer = MistralTokenizer(str(tokenizer_path.absolute())) + print("\n-> Keras 3 model and tokenizer loaded.") + + # === Port the weights === + convert_checkpoints(keras_nlp_model, hf_model) + print("\n-> Weight transfer done.") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_nlp_tokenizer, hf_tokenizer) + test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer) + print("\n-> Tests passed!") + + # === 
Save the model weights in float32 format === + keras_nlp_model.save_weights( + str((model_dir / "model.weights.h5").absolute()) + ) + print("\n-> Saved the model weights in float32") + + del keras_nlp_model, hf_model + gc.collect() + + keras_nlp_config["dtype"] = "float16" + + # === Save the weights again in float16 === + keras_nlp_model = MistralBackbone(**keras_nlp_config) + keras_nlp_model.load_weights( + str((model_dir / "model.weights.h5").absolute()) + ) + keras_nlp_model.save_weights( + str((model_dir / "model.weights.h5").absolute()) + ) + print("-> Saved the model weights in float16") + + # === Save the model config === + keras_nlp_config["dtype"] = "bfloat16" + model_config = { + "module": "keras_nlp.src.models.mistral.mistral_backbone", + "class_name": "MistralBackbone", + "config": {**keras_nlp_config}, + "registered_name": "keras_nlp>MistralBackbone", + "assets": [], + "weights": "model.weights.h5", + } + model_config_json = json.dumps(model_config) + with open(model_dir / "config.json", "w") as model_config_file: + model_config_file.write(model_config_json) + print("\n-> Saved model config") + + # === Save the tokenizer config === + tokenizer_config = { + "module": "keras_nlp.src.models.mistral.mistral_tokenizer", + "class_name": "MistralTokenizer", + "config": { + "name": "mistral_tokenizer", + "trainable": True, + "dtype": "int32", + "proto": None, + "sequence_length": None, + }, + "registered_name": "keras_nlp>MistralTokenizer", + "assets": ["assets/tokenizer/vocabulary.spm"], + "weights": None, + } + tokenizer_config_json = json.dumps(tokenizer_config) + with open(model_dir / "tokenizer.json", "w") as tokenizer_config_file: + tokenizer_config_file.write(tokenizer_config_json) + print("\n-> Saved tokenizer config") + + # === Save metadata === + metadata_config = { + "keras_version": keras.__version__, + "keras_nlp_version": keras_nlp.__version__, + "parameter_count": keras_nlp_model.count_params(), + "date_saved": datetime.datetime.utcnow().strftime("%Y-%m-%d@%H:%M:%S"), + } + metadata_config_json = json.dumps(metadata_config) + with open(model_dir / "metadata.json", "w") as metadata_config_file: + metadata_config_file.write(metadata_config_json) + print("\n-> Saved metadata") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) diff --git a/tools/gemma/export_gemma_to_hf.py b/tools/gemma/export_gemma_to_hf.py new file mode 100644 index 0000000000..31e3f3c69b --- /dev/null +++ b/tools/gemma/export_gemma_to_hf.py @@ -0,0 +1,328 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
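The Mistral converter above writes the weights twice: first from the float32 backbone, then again after rebuilding the backbone with `dtype="float16"` and reloading, so the preset ends up with half-precision weights. A toy sketch of that save/rebuild/reload pattern, using a small hypothetical `keras.Sequential` model in place of `MistralBackbone`:

```python
import keras

# Hypothetical small model standing in for MistralBackbone; only the
# save -> rebuild at lower precision -> reload -> re-save pattern matters.
def build_model(dtype):
    return keras.Sequential(
        [keras.Input(shape=(4,)), keras.layers.Dense(2, dtype=dtype)]
    )

model = build_model("float32")
model.save_weights("model.weights.h5")

# Rebuild the same architecture in half precision and reload: the stored
# float32 values are cast into the float16 variables on assignment, so
# saving again roughly halves the weights file.
half_model = build_model("float16")
half_model.load_weights("model.weights.h5")
half_model.save_weights("model.weights.h5")
```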
+ + +import os + +import torch +import transformers +from absl import app +from absl import flags + +import keras_nlp + +os.environ["KERAS_BACKEND"] = "torch" + +""" +Sample usage: + +For converting a keras model to HuggingFace format using a custom or fine-tuned +checkpoint from Keras, make sure to pass the path for the Keras weights file +(ending in `.weights.h5`), the model size (`2b` or `7b`), and the tokenizer +vocabulary file (`.spm`, `.model`, or equivalent) to +`--weights_file`, `--size`, and `--vocab_path`, respectively. + +Optionally, you can specify the output directory +for the converted model at `--output_dir`. (defaults to `gg_hf`) +``` +python tools/gemma/export_gemma_to_hf.py \ + --weights_file fine_tuned_imdb.weights.h5 \ + --size 2b \ + --vocab_path gemma_lm_tokenizer/vocabulary.spm \ + --output_dir fine_tuned_gg_hf +``` + +For converting a Keras model to HuggingFace format from a preset, +simply pass the Keras preset name to `--preset` and its model size +(`2b` or `7b`) to `--size`. +``` +python tools/gemma/export_gemma_to_hf.py \ + --preset gemma_2b_en \ + --size 2b \ + --output_dir keras_hf_model/ +``` +""" + + +PRESET_MAP = { + "gemma_2b_en": "gg-hf/gemma-2b", + "gemma_instruct_2b_en": "gg-hf/gemma-2b", + "gemma_7b_en": "gg-hf/gemma-7b", + "gemma_instruct_7b_en": "gg-hf/gemma-7b", +} + +SIZE_MAP = { + "2b": ("gg-hf/gemma-2b", "gemma_2b_en"), + "7b": ("gg-hf/gemma-7b", "gemma_7b_en"), +} + +gemma_2b_config = transformers.GemmaConfig( + num_hidden_layers=18, + num_attention_heads=8, + num_key_value_heads=1, + hidden_size=2048, + intermediate_size=16384, +) + +gemma_7b_config = transformers.GemmaConfig() + +CONFIG_MAPPING = {"2b": gemma_2b_config, "7b": gemma_7b_config} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "hf_token", + None, + "Your HuggingFace token. Needed for access to the HuggingFace Gemma" + "implementation since the repository is private, for now.", +) +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}' + " Alternatively, a Keras weights file (`.weights.h5`) can be passed" + " to --weights_file flag.", +) +flags.DEFINE_string( + "weights_file", + None, + "A Keras weights file (`.weights.h5`)." + " Alternatively, a model preset can be passed to --preset flag.", +) +flags.DEFINE_string( + "size", + None, + "Size of model. Must be passed if `weights_file` is passed. " + "This should be either `2b` or `7b`.", +) +flags.DEFINE_string( + "output_dir", + "gg_hf", + "An output directory for the converted HuggingFace model and tokenizer.", +) +flags.DEFINE_string( + "vocab_path", + None, + "A path containing the vocabulary (must be a `.spm` file or equivalent). 
" + "If not passed, the vocabulary of the preset will be used.", +) + + +def convert_checkpoints(preset, weights_file, size, output_dir, vocab_path): + if preset is not None: + hf_id = PRESET_MAP[preset] + print(f"\n-> Loading KerasNLP Gemma model with preset `{preset}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset(preset) + else: + hf_id, keras_preset = SIZE_MAP[size.lower()] + print(f"\n-> Loading Keras weights from file `{weights_file}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset( + keras_preset + ) + keras_nlp_model.load_weights(weights_file) + + print(f"\n-> Loading HuggingFace Gemma `{size.upper()}` model...") + hf_model = transformers.GemmaForCausalLM(CONFIG_MAPPING[size.lower()]) + + print("\nāœ… Model loading complete.") + print("\n-> Converting weights from KerasNLP Gemma to HuggingFace Gemma...") + + # Token embedding (with vocab size difference handling) + keras_embedding = keras_nlp_model.backbone.token_embedding.weights[0] + hf_vocab_size = hf_model.model.embed_tokens.weight.shape[0] + keras_nlp_vocab_size = keras_embedding.value.shape[0] + if hf_vocab_size < keras_nlp_vocab_size: + diff = keras_nlp_vocab_size - hf_vocab_size + update_state_dict( + hf_model.model.embed_tokens, + "weight", + keras_embedding.value[:-diff, :], + ) + else: + update_state_dict( + hf_model.model.embed_tokens, + "weight", + keras_embedding.value, + ) + + # Decoder blocks + for i in range(keras_nlp_model.backbone.num_layers): + decoder_block = keras_nlp_model.backbone.get_layer(f"decoder_block_{i}") + + # Pre-attention norm + update_state_dict( + hf_model.model.layers[i].input_layernorm, + "weight", + decoder_block.pre_attention_norm.weights[0].value, + ) + + # Attention + query_target_shape = hf_model.model.layers[ + i + ].self_attn.q_proj.weight.shape + query_tensor = decoder_block.attention.query_dense.weights[0].value + query_tensor = query_tensor.transpose(1, 2).reshape(query_target_shape) + update_state_dict( + hf_model.model.layers[i].self_attn.q_proj, "weight", query_tensor + ) + + key_target_shape = hf_model.model.layers[ + i + ].self_attn.k_proj.weight.shape + key_tensor = decoder_block.attention.key_dense.weights[0].value + key_tensor = key_tensor.transpose(1, 2).reshape(key_target_shape) + update_state_dict( + hf_model.model.layers[i].self_attn.k_proj, "weight", key_tensor + ) + + value_target_shape = hf_model.model.layers[ + i + ].self_attn.v_proj.weight.shape + value_tensor = decoder_block.attention.value_dense.weights[0].value + value_tensor = value_tensor.transpose(1, 2).reshape(value_target_shape) + update_state_dict( + hf_model.model.layers[i].self_attn.v_proj, "weight", value_tensor + ) + + out_target_shape = hf_model.model.layers[ + i + ].self_attn.o_proj.weight.shape + keras_out_tensor = decoder_block.attention.output_dense.weights[0].value + out_tensor = keras_out_tensor.reshape( + (out_target_shape[1], out_target_shape[0]) # Transpose target size + ).transpose(0, 1) + + update_state_dict( + hf_model.model.layers[i].self_attn.o_proj, "weight", out_tensor + ) + + # Post-attention norm + update_state_dict( + hf_model.model.layers[i].post_attention_layernorm, + "weight", + decoder_block.pre_ffw_norm.weights[0].value, + ) + + # MLP (Feed-forward) + update_state_dict( + hf_model.model.layers[i].mlp.gate_proj, + "weight", + decoder_block.gating_ffw.weights[0].value.transpose(0, 1), + ) + update_state_dict( + hf_model.model.layers[i].mlp.up_proj, + "weight", + decoder_block.gating_ffw_2.weights[0].value.transpose(0, 1), + ) + 
update_state_dict(
+            hf_model.model.layers[i].mlp.down_proj,
+            "weight",
+            decoder_block.ffw_linear.weights[0].value.transpose(0, 1),
+        )
+
+    # Final norm
+    update_state_dict(
+        hf_model.model.norm,
+        "weight",
+        keras_nlp_model.backbone.layers[-1].weights[0].value,
+    )
+
+    print("\nāœ… Weights converted successfully.")
+    print(f"\n-> Saving HuggingFace model to `{output_dir}`...")
+
+    # Save model to HF Transformers format
+    os.makedirs(output_dir, exist_ok=True)
+    hf_model.save_pretrained(output_dir)
+
+    print(f"\nāœ… Saving complete. Model saved at `{output_dir}`.")
+
+    # Tokenizer
+
+    if not vocab_path:
+        # SIZE_MAP values are (hf_id, keras_preset) tuples; the tokenizer
+        # needs the Keras preset name.
+        tokenizer_preset = preset or SIZE_MAP[size.lower()][1]
+        print(
+            "\n-> Loading KerasNLP Gemma tokenizer with "
+            f"preset `{tokenizer_preset}`..."
+        )
+        keras_nlp_tokenizer = keras_nlp.models.GemmaTokenizer.from_preset(
+            tokenizer_preset
+        )
+        # Save tokenizer state
+        keras_nlp_tokenizer.save_assets(output_dir)
+        vocab_path = os.path.join(output_dir, "vocabulary.spm")
+        print("\nāœ… Tokenizer loading complete.")
+
+    hf_tokenizer = transformers.GemmaTokenizer(vocab_path)
+
+    print(f"\n-> Saving HuggingFace Gemma tokenizer to `{output_dir}`...")
+    # Save tokenizer to HF Transformers format
+    hf_tokenizer.save_pretrained(output_dir)
+
+    print(f"\nāœ… Saving complete. Tokenizer saved at `{output_dir}`.")
+
+
+def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None:
+    """Updates the state dict for a weight given a tensor."""
+    assert (
+        tensor.shape == layer.state_dict()[weight_name].shape
+    ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}"
+    layer.state_dict()[weight_name].copy_(tensor)
+
+
+def flag_error_handler():
+    if not FLAGS.preset and not FLAGS.weights_file:
+        raise ValueError(
+            "Please pass either a valid Keras preset to `--preset`"
+            " or supply a Keras weights file (`.weights.h5`) and model size"
+            " (`2b` or `7b`) to `--weights_file` and `--size`, respectively."
+        )
+    if FLAGS.weights_file:
+        if FLAGS.preset:
+            raise ValueError(
+                "Both `--preset` and `--weights_file` flags cannot be supplied "
+                "at the same time. Either supply a valid Keras preset to "
+                "`--preset` or supply a Keras `.weights.h5` file and "
+                "model size (`2b` or `7b`) to `--weights_file` and `--size`, "
+                "respectively."
+            )
+        if not str(FLAGS.weights_file).endswith(".weights.h5"):
+            raise ValueError(
+                "Please pass a valid Keras weights file ending in `.weights.h5`."
+            )
+        if not FLAGS.size:
+            raise ValueError(
+                "The `size` flag must be passed if a weights file is passed. "
+                "Please pass the appropriate size (`2b` or `7b`) for your "
+                "model to the `--size` flag."
+            )
+    if FLAGS.size.lower() not in ["2b", "7b"]:
+        raise ValueError(
+            "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) "
+            "for your model to the `--size` flag."
+        )
+
+
+def main(_):
+    flag_error_handler()
+    convert_checkpoints(
+        FLAGS.preset,
+        FLAGS.weights_file,
+        FLAGS.size,
+        FLAGS.output_dir,
+        FLAGS.vocab_path,
+    )
+
+
+if __name__ == "__main__":
+    flags.mark_flag_as_required("size")
+    app.run(main)
diff --git a/tools/gemma/export_gemma_to_torch_xla.py b/tools/gemma/export_gemma_to_torch_xla.py
new file mode 100644
index 0000000000..005eac272d
--- /dev/null
+++ b/tools/gemma/export_gemma_to_torch_xla.py
@@ -0,0 +1,322 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os + +import gemma +import torch +import torch_xla.core.xla_model as xm +from absl import app +from absl import flags +from gemma import model_xla as gemma_model + +import keras_nlp + +os.environ["KERAS_BACKEND"] = "torch" + +""" +Sample usage: + +For converting a Keras model to PyTorch format using a custom or fine-tuned +checkpoint from Keras, make sure to pass the path for the Keras weights file +(ending in `.weights.h5`) and the model size (`2b` or `7b`) to `--weights_file` +and `--size`, respectively. + +Optionally, you can specify the output path for the converted model at +`--output_file`. (This defaults to `gemma.ckpt`) +``` +python tools/gemma/export_gemma_to_torch_xla.py \ + --weights_file fine_tuned_imdb.weights.h5 \ + --size 2b \ + --output_file fine_tuned_imdb.ckpt +``` + +For converting a Keras model to PyTorch format from a preset, +simply pass the Keras preset name to `--preset`. +``` +python tools/gemma/export_gemma_to_torch_xla.py \ + --preset gemma_2b_en \ + --output_file path/to/keras_torch_model.ckpt +``` +""" + + +PRESET_MAP = { + "gemma_2b_en": gemma.config.get_config_for_2b(), + "gemma_instruct_2b_en": gemma.config.get_config_for_2b(), + "gemma_7b_en": gemma.config.get_config_for_7b(), + "gemma_instruct_7b_en": gemma.config.get_config_for_7b(), +} + +SIZE_MAP = { + "2b": (gemma.config.get_config_for_2b(), "gemma_2b_en"), + "7b": (gemma.config.get_config_for_7b(), "gemma_7b_en"), +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}' + " Alternatively, a Keras weights file (`.weights.h5`) can be passed" + " to --weights_file flag.", +) +flags.DEFINE_string( + "weights_file", + None, + "A Keras weights file (`.weights.h5`)." + " Alternatively, a model preset can be passed to --preset flag.", +) +flags.DEFINE_string( + "size", + None, + "Size of model. Must be passed if `weights_file` is passed. " + "This should be either `2b` or `7b`.", +) +flags.DEFINE_string( + "output_file", + "gemma.ckpt", + "An output file for the converted PyTorch checkpoint. Default: `gemma.ckpt`", +) +flags.DEFINE_string( + "vocab_dir", + "gemma_tokenizer", + "A directory in which the vocabulary for the tokenizer will be stored.", +) +flags.DEFINE_string( + "dtype", + "float32", + "Set the precision of the converted checkpoint. Must be a valid PyTorch dtype.", +) + + +@contextlib.contextmanager +def _set_default_tensor_type(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(torch.float) + + +def _reconcile_attention_dims(qkv, target_shape): + return torch.cat(qkv).reshape(tuple(target_shape)) + + +def convert_checkpoints(preset, weights_file, size, output_file, vocab_dir): + device = xm.xla_device() + + if preset is not None: + print( + f"\n-> Loading PyTorch Gemma model config for preset `{preset}`..." 
+ ) + model = gemma_model.GemmaForCausalLM( + PRESET_MAP[preset], world_size=1, rank=0, device=device + ) + print(f"\n-> Loading KerasNLP Gemma model with preset `{preset}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset(preset) + else: + print(f"\n-> Loading PyTorch Gemma model config for `{size}` model...") + config, size_preset = SIZE_MAP[size.lower()] + model = gemma_model.GemmaForCausalLM( + config, world_size=1, rank=0, device=device + ) + print(f"\n-> Loading Keras weights from file `{weights_file}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset( + size_preset + ) + keras_nlp_model.load_weights(weights_file) + + print("\nāœ… Model loading complete.") + print("\n-> Converting weights from KerasNLP Gemma to PyTorch Gemma...") + + # Token embedding (with vocab size difference handling) + keras_embedding = keras_nlp_model.backbone.token_embedding.weights[0] + torch_vocab_size = model.embedder.weight.shape[0] + keras_nlp_vocab_size = keras_embedding.value.shape[0] + if torch_vocab_size < keras_nlp_vocab_size: + diff = keras_nlp_vocab_size - torch_vocab_size + update_state_dict( + model.embedder, + "weight", + keras_embedding.value[:-diff, :], + ) + else: + update_state_dict( + model.embedder, + "weight", + keras_embedding.value, + ) + + # Decoder blocks + for i in range(keras_nlp_model.backbone.num_layers): + decoder_block = keras_nlp_model.backbone.get_layer(f"decoder_block_{i}") + # Pre-attention norm + update_state_dict( + model.model.layers[i].input_layernorm, + "weight", + decoder_block.pre_attention_norm.weights[0].value, + ) + + # Attention + qkv = ( + decoder_block.attention.query_dense.weights[0].value.transpose( + 1, 2 + ), + decoder_block.attention.key_dense.weights[0].value.transpose(1, 2), + decoder_block.attention.value_dense.weights[0].value.transpose( + 1, 2 + ), + ) + qkv_target_shape = model.model.layers[i].self_attn.qkv_proj.weight.shape + combined_tensor = _reconcile_attention_dims(qkv, qkv_target_shape) + + update_state_dict( + model.model.layers[i].self_attn.qkv_proj, "weight", combined_tensor + ) + + out_target_shape = model.model.layers[i].self_attn.o_proj.weight.shape + keras_out_tensor = decoder_block.attention.output_dense.weights[0].value + out_tensor = keras_out_tensor.reshape( + (out_target_shape[1], out_target_shape[0]) # Transpose target size + ).transpose(0, 1) + + update_state_dict( + model.model.layers[i].self_attn.o_proj, "weight", out_tensor + ) + + # Post-attention norm + update_state_dict( + model.model.layers[i].post_attention_layernorm, + "weight", + decoder_block.pre_ffw_norm.weights[0].value, + ) + + # MLP (Feed-forward) + update_state_dict( + model.model.layers[i].mlp.gate_proj, + "weight", + decoder_block.gating_ffw.weights[0].value.transpose(0, 1), + ) + update_state_dict( + model.model.layers[i].mlp.up_proj, + "weight", + decoder_block.gating_ffw_2.weights[0].value.transpose(0, 1), + ) + update_state_dict( + model.model.layers[i].mlp.down_proj, + "weight", + decoder_block.ffw_linear.weights[0].value.transpose(0, 1), + ) + + # Final norm + update_state_dict( + model.model.norm, + "weight", + keras_nlp_model.backbone.layers[-1].weights[0].value, + ) + + print("\nāœ… Weights converted successfully.") + print(f"\n-> Saving PyTorch model checkpoint to `{output_file}`...") + + # Save model checkpoint + torch.save({"model_state_dict": model.state_dict()}, output_file) + + print( + f"\nāœ… Saving complete. Model checkpoint available at `{output_file}`." 
+ ) + + if preset is not None: + # Tokenizer + print( + f"\n-> Loading KerasNLP Gemma tokenizer with preset `{preset}`..." + ) + keras_nlp_tokenizer = keras_nlp.models.GemmaTokenizer.from_preset( + preset + ) + print("\nāœ… Model loading complete.") + print(f"\n-> Saving tokenizer state to directory `{vocab_dir}`...") + + # Save tokenizer state + os.makedirs(vocab_dir, exist_ok=True) + keras_nlp_tokenizer.save_assets(vocab_dir) + + print( + "\nāœ… Saving complete. Tokenizer state " + f"available at `{vocab_dir}/vocabulary.spm`." + ) + + +def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None: + """Updates the state dict for a weight given a tensor.""" + assert ( + tensor.shape == layer.state_dict()[weight_name].shape + ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}" + layer.state_dict()[weight_name].copy_(tensor) + + +def flag_error_handler(): + if not FLAGS.preset and not FLAGS.weights_file: + raise ValueError( + "Please pass either a valid Keras preset to `--preset`" + " or supply a Keras weights file (`.weights.h5`) and model size" + " (`2b` or `7b`) to `--weights_file` and `--size`, respectively." + ) + if FLAGS.weights_file: + if FLAGS.preset: + raise ValueError( + "Both `--preset` and `--weights_file` flags cannot be supplied " + "at the same time. Either supply a valid Keras preset to " + "`--preset`or supply a Keras `.weights.h5` file and " + "model size (`2b` or `7b`) to `--weights_file` and `--size`, " + "respectively." + ) + if not str(FLAGS.weights_file).endswith(".weights.h5"): + raise ValueError( + "Please pass a valid Keras weights file ending in `.weights.h5`." + ) + if not FLAGS.size: + raise ValueError( + "The `size` flag must be passed if a weights file is passed. " + "Please pass the appropriate size (`2b` or `7b`) for your " + "model to the `--size` flag." + ) + if FLAGS.size.lower() not in ["2b", "7b"]: + raise ValueError( + "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) " + "for your model to the `--size` flag." + ) + if FLAGS.dtype: + dtype = getattr(torch, FLAGS.dtype) + if not isinstance(dtype, torch.dtype): + raise ValueError( + "Invalid `dtype`. Please pass a valid PyTorch data type (e.g. " + "`float32', 'float16`, etc.) to the `--dtype` flag." + ) + + +def main(_): + flag_error_handler() + with _set_default_tensor_type(getattr(torch, FLAGS.dtype)): + convert_checkpoints( + FLAGS.preset, + FLAGS.weights_file, + FLAGS.size, + FLAGS.output_file, + FLAGS.vocab_dir, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/tools/gemma/run_gemma_xla.py b/tools/gemma/run_gemma_xla.py new file mode 100644 index 0000000000..9fa50cbd2b --- /dev/null +++ b/tools/gemma/run_gemma_xla.py @@ -0,0 +1,287 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
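A detail worth noting in both export scripts above: the MLP kernels are copied with a `.transpose(0, 1)`. Keras `Dense` layers store their kernel as `(in_features, out_features)`, while `torch.nn.Linear.weight` is laid out as `(out_features, in_features)`, so the kernel must be transposed before `copy_`. A toy sketch of that convention difference (shapes are arbitrary and purely illustrative):

```
import torch

# Stand-in for a Keras Dense kernel: shape (in_features, out_features).
hidden_dim, intermediate_dim = 8, 32
keras_style_kernel = torch.randn(hidden_dim, intermediate_dim)

linear = torch.nn.Linear(hidden_dim, intermediate_dim, bias=False)
print(tuple(keras_style_kernel.shape))  # (8, 32) -> (in, out)
print(tuple(linear.weight.shape))       # (32, 8) -> (out, in)

# Same pattern as the gate_proj / up_proj / down_proj copies above.
with torch.no_grad():
    linear.weight.copy_(keras_style_kernel.transpose(0, 1))
```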
+import contextlib +import os +import random +import sys +from typing import List + +import gemma.xla_model_parallel as xla_model_parallel +import numpy as np +import torch +import torch.multiprocessing +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +from absl import app +from absl import flags +from gemma.config import GemmaConfig +from gemma.config import get_config_for_2b +from gemma.config import get_config_for_7b +from gemma.model_xla import GemmaForCausalLM +from gemma.tokenizer import Tokenizer + +PAD_TOKEN_ID = -1 + +FILE_PATH = "gemma.ckpt" +TOKENIZER_DIR = "gemma_tokenizer" + +PRESET_MAP = { + "gemma_2b_en": get_config_for_2b(), + "gemma_instruct_2b_en": get_config_for_2b(), + "gemma_7b_en": get_config_for_7b(), + "gemma_instruct_7b_en": get_config_for_7b(), +} + +SIZE_MAP = { + "2b": get_config_for_2b(), + "7b": get_config_for_7b(), +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' +) +flags.DEFINE_string( + "size", + None, + "Size of model. Must be passed if `preset` is not passed. " + "This should be either `2b` or `7b`.", +) +flags.DEFINE_string( + "checkpoint_file", + "gemma.ckpt", + "A PyTorch checkpoint file containing the converted weights.", +) +flags.DEFINE_string( + "vocab_file", + "gemma_tokenizer/vocabulary.spm", + "The file containing the vocabulary for the tokenizer.", +) +flags.DEFINE_string( + "prompt", + "The capital of France is", + "A test prompt for verifying functionality of the PyTorch Gemma model.", +) + +# This is a modified version of `run_xla.py` script in the Hex-LLM Gemma repo +# to ensure proper functionality after porting checkpoints from Keras. + + +@contextlib.contextmanager +def _set_default_tensor_type(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(torch.float) + + +def generate( + i: int, + model_config: GemmaConfig, + checkpoint_file: str, + vocab_file: str, + prompts: List[str], + output_lens: List[int], + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], +): + # Set seed from config + seed = model_config.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + device = xm.xla_device() + xm.set_rng_state(seed, device) + + rank = xla_model_parallel.get_model_parallel_rank() + world_size = xla_model_parallel.get_model_parallel_world_size() + if rank > 0: + sys.stdout = open(os.devnull, "w") + + # Load model with ported weights and place on device + with _set_default_tensor_type(model_config.get_dtype()): + model = GemmaForCausalLM(model_config, world_size, rank, device) + model.load_weights(checkpoint_file) + model = model.to(device).eval() + + # Create tokenizer with saved Keras tokenizer state + tokenizer = Tokenizer(vocab_file) + + prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts] + min_prompt_len = min(len(p) for p in prompt_tokens) + + batch_size = len(prompts) + assert batch_size == len(temperatures) + assert batch_size == len(top_ps) + assert batch_size == len(top_ks) + max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)]) + assert max_seq_len <= model_config.max_position_embeddings + if model_config.num_key_value_heads < world_size: + assert world_size % model_config.num_key_value_heads == 0 + n_local_heads = 1 + else: + assert model_config.num_key_value_heads % world_size == 0 + n_local_heads = model_config.num_key_value_heads // world_size + + # 
build KV caches + kv_caches = [] + for _ in range(model_config.num_hidden_layers): + k_cache = torch.zeros( + size=( + batch_size, + max_seq_len, + n_local_heads, + model_config.head_dim, + ), + dtype=model_config.get_dtype(), + device=device, + ) + v_cache = torch.zeros( + size=( + batch_size, + max_seq_len, + n_local_heads, + model_config.head_dim, + ), + dtype=model_config.get_dtype(), + device=device, + ) + kv_caches.append((k_cache, v_cache)) + + # prepare inputs + token_ids_tensor = torch.full( + (batch_size, max_seq_len), PAD_TOKEN_ID, dtype=torch.int64 + ) + input_token_ids_tensor = torch.full( + (batch_size, min_prompt_len), PAD_TOKEN_ID, dtype=torch.int64 + ) + for i, p in enumerate(prompt_tokens): + token_ids_tensor[i, : len(p)] = torch.tensor(p) + input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( + p[:min_prompt_len] + ) + token_ids_tensor = token_ids_tensor.to(device) + prompt_mask_tensor = token_ids_tensor != PAD_TOKEN_ID + input_token_ids_tensor = input_token_ids_tensor.to(device) + input_positions_tensor = torch.arange( + 0, min_prompt_len, dtype=torch.int64 + ).to(device) + mask_tensor = torch.full( + (1, 1, max_seq_len, max_seq_len), -2.3819763e38 + ).to(torch.float) + mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device) + temperatures_tensor = torch.FloatTensor(temperatures).to(device) + top_ps_tensor = torch.FloatTensor(top_ps).to(device) + top_ks_tensor = torch.LongTensor(top_ks).to(device) + output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device) + xm.mark_step() + + # Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output. + for i in range(max_seq_len - min_prompt_len): + next_token_ids = model( + input_token_ids=input_token_ids_tensor, + input_positions=input_positions_tensor, + kv_write_indices=None, + kv_caches=kv_caches, + mask=curr_mask_tensor, + output_positions=output_positions_tensor, + temperatures=temperatures_tensor, + top_ps=top_ps_tensor, + top_ks=top_ks_tensor, + ) + curr_prompt_mask = prompt_mask_tensor.index_select( + 1, output_index + ).squeeze(dim=1) + curr_token_ids = token_ids_tensor.index_select(1, output_index).squeeze( + dim=1 + ) + output_token_ids = torch.where( + curr_prompt_mask, curr_token_ids, next_token_ids + ).unsqueeze(dim=1) + token_ids_tensor.index_copy_(1, output_index, output_token_ids) + + input_token_ids_tensor = output_token_ids + input_positions_tensor = output_index + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device) + output_index = output_index + 1 + xm.mark_step() + + # Detokenization. 
+ token_ids = token_ids_tensor.tolist() + results = [] + for i, tokens in enumerate(token_ids): + trimmed_output = tokens[ + len(prompt_tokens[i]) : len(prompt_tokens[i]) + output_lens[i] + ] + if tokenizer.eos_id in trimmed_output: + eos_index = trimmed_output.index(tokenizer.eos_id) + trimmed_output = trimmed_output[:eos_index] + results.append(tokenizer.decode(trimmed_output)) + + for prompt, result in zip(prompts, results): + print("======================================") + print(f"PROMPT: {prompt}") + print(f"RESULT: {result}") + print("======================================") + + +def flag_error_handler(): + if not FLAGS.preset and not FLAGS.size: + raise ValueError( + "Please pass either a valid Keras preset to `--preset`" + " or supply a model size (`2b` or `7b`) to `--size`." + ) + if FLAGS.size and FLAGS.size.lower() not in ["2b", "7b"]: + raise ValueError( + "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) " + "for your model to the `--size` flag." + ) + + +def main(_): + flag_error_handler() + if FLAGS.preset: + model_config = PRESET_MAP[FLAGS.preset] + else: + model_config = SIZE_MAP[FLAGS.size.lower()] + prompts = [ + FLAGS.prompt, + ] + n = len(prompts) + output_lengths = [10] * n + temperatures = [0.95] * n + top_ps = [1.0] * n + top_ks = [100] * n + xmp.spawn( + generate, + args=( + model_config, + FLAGS.checkpoint_file, + FLAGS.vocab_file, + prompts, + output_lengths, + temperatures, + top_ps, + top_ks, + ), + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/tools/sentencepiece_testing/create_gemma_test_proto.py b/tools/sentencepiece_testing/create_gemma_test_proto.py new file mode 100644 index 0000000000..c3ce418a4b --- /dev/null +++ b/tools/sentencepiece_testing/create_gemma_test_proto.py @@ -0,0 +1,36 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tools.sentencepiece_testing.utils import train_sentencepiece + + +def main(): + train_sentencepiece( + ["the quick brown fox", "the earth is round"], + "gemma_test_vocab.spm", + vocab_size=11, + model_type="WORD", + pad_id=0, + bos_id=1, + eos_id=2, + unk_id=3, + pad_piece="", + bos_piece="", + eos_piece="", + unk_piece="", + ) + + +if __name__ == "__main__": + main() diff --git a/tools/sentencepiece_testing/create_llama_test_proto.py b/tools/sentencepiece_testing/create_llama_test_proto.py new file mode 100644 index 0000000000..c57a0074e2 --- /dev/null +++ b/tools/sentencepiece_testing/create_llama_test_proto.py @@ -0,0 +1,32 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from tools.sentencepiece_testing.utils import train_sentencepiece + + +def main(): + train_sentencepiece( + ["the quick brown fox", "the earth is round"], + "llama_test_vocab.spm", + vocab_size=10, + model_type="WORD", + pad_id=-1, + unk_id=0, + bos_id=1, + eos_id=2, + ) + + +if __name__ == "__main__": + main()
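As an aside, the word-level protos produced by these two scripts are tiny fixtures meant only for unit tests. If you want to inspect one locally, the `sentencepiece` package can load it directly. A minimal sketch, assuming `llama_test_vocab.spm` is the file written by the script above (adjust the path to wherever your checkout places it):

```
import sentencepiece as spm

# Load the tiny word-level test proto generated above.
sp = spm.SentencePieceProcessor(model_file="llama_test_vocab.spm")

ids = sp.encode("the quick brown fox", out_type=int)
print("vocab size:", sp.vocab_size())
print("token ids:", ids)
print("round trip:", sp.decode(ids))
```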