From a76a91b87dcacccf6936873b930c1201dfc871a0 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Thu, 18 Jan 2024 23:58:11 +0200 Subject: [PATCH 01/24] Add AlibiBias layer --- keras_nlp/layers/modeling/alibi_bias.py | 150 ++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 keras_nlp/layers/modeling/alibi_bias.py diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py new file mode 100644 index 0000000000..9b7082052e --- /dev/null +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -0,0 +1,150 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +@keras_nlp_export("keras_nlp.layers.AlibiBias") +class AlibiBias(keras.layers.Layer): + """A layer that generates alibi bias + + This layer generates a linear, non-learned bias. Defined and formalized in + [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409). + + Takes as input an embedded token tensor. The input must have shape + `(batch_size, sequence_length, hidden_dim)`. This layer will return an alibi + bias of the shape `(1, num_heads, 1, sequence_length)`, which will be added to + the result of the query-key dot product in the multi-head attention layer of + the transformer. + + Args: + num_heads: int. The number of heads in the multi-head attention layer of + the transformer. + alibi_bias_max: int. This value will be used to compute the slope of + each head. The heads slopes is a geometric sequence that starts at + `2**(-alibi_bias_max/num_heads)` and uses that same value as its + ratio. Defaults to 8. + full: bool. Whether to return the full alibi bias tensor. If set to + `True`, the alibi bias shape will be + `(1, num_heads, sequence_length, sequence_length)`. Defaults to + `False`, so the third dimension will be broadcasted, and this will + work because of the translation invariance property of the softmax, + let `L` be a tensor and `x` a constant, `softmax(L+x) = softmax(L)` + batched: bool. Whether to return the alibi bias tensor with first + dimention equal to `batch_size`. If set to `True` the alibi bias + shape wil be `(batch_size, num_heads, 1, sequence_length)`. Defaults + to `False`, so the first dimension will be broadcasted. + Call arguments: + inputs: The tensor inputs to compute an embedding for, with shape + `(batch_size, sequence_length, hidden_dim)`. + + Examples: + ```python + # create a simple embedding layer with sinusoidal positional encoding + seq_len = 100 + vocab_size = 1000 + embedding_dim = 32 + inputs = keras.Input((seq_len,), dtype="float32") + embedding = keras.layers.Embedding( + input_dim=vocab_size, output_dim=embedding_dim + )(inputs) + positional_encoding = keras_nlp.layers.SinePositionEncoding()(embedding) + outputs = embedding + positional_encoding + ``` + + References: + - [Press et al., 2021](https://arxiv.org/abs/2108.12409) + """ + + def __init__( + self, + num_heads, + alibi_bias_max=8, + full=False, + batched=False, + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.alibi_bias_max = alibi_bias_max + self.full = full + self.batched = batched + + def call(self, inputs): + shape = ops.shape(inputs) + batch_size = shape[0] + seq_length = shape[1] + + slopes = ops.convert_to_tensor(self._get_slopes(), dtype=float) + slopes = ops.reshape(slopes, (self.num_heads, 1, 1)) + + + sequence_range = ops.expand_dims(ops.arange( 1 - seq_length, 1, dtype=float), 0) + if self.full: + sequence_range = ops.subtract(sequence_range, ops.expand_dims(ops.arange( 1 - seq_length, 1, dtype=float), 1)) + sequence_range = ops.multiply(ops.abs(sequence_range), -1) + + alibi_bias = slopes * ops.expand_dims(ops.arange(seq_length, dtype=float), 0) + alibi_bias = ops.expand_dims(alibi_bias, 0) + if self.batched: + return ops.repeat(alibi_bias, batch_size, axis=0) + + return alibi_bias + + def _get_slopes(self): + # this function is adopted from Alibi original implementation + # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + def get_slopes_power_of_2(n): + start = 2 ** ( + -(2 ** -(math.log2(n) - math.log2(self.alibi_bias_max))) + ) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(self.num_heads).is_integer(): + return get_slopes_power_of_2(self.num_heads) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + self._get_slopes(2 * closest_power_of_2)[0::2][ + : self.num_heads - closest_power_of_2 + ] + ) + + def compute_output_shape(self, input_shape): + batch_size = input_shape[0] + seq_length = input_shape[1] + output_shape = [1, self.num_heads, 1, seq_length] + if self.full: + output_shape[2] = seq_length + if self.batched: + output_shape[0] = batch_size + + return tuple(output_shape) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "alibi_bias_max": self.alibi_bias_max, + "full": self.full, + "batched": self.batched, + } + ) + return config \ No newline at end of file From 3f3d9686b0fcd4426a5d8ad92ffd9ebb20c2b66f Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Fri, 19 Jan 2024 00:05:16 +0200 Subject: [PATCH 02/24] Add example --- keras_nlp/layers/modeling/alibi_bias.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 9b7082052e..62638033f2 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -54,7 +54,8 @@ class AlibiBias(keras.layers.Layer): Examples: ```python - # create a simple embedding layer with sinusoidal positional encoding + # create a simple layer that takes token embeddings as input and generates + # the alibi tensor seq_len = 100 vocab_size = 1000 embedding_dim = 32 @@ -62,8 +63,7 @@ class AlibiBias(keras.layers.Layer): embedding = keras.layers.Embedding( input_dim=vocab_size, output_dim=embedding_dim )(inputs) - positional_encoding = keras_nlp.layers.SinePositionEncoding()(embedding) - outputs = embedding + positional_encoding + alibi_bias = keras_nlp.layers.AlibiBias(num_heads=8)(embedding) ``` References: From f1536df2fca9d999d0a6cb9997c2dc04876b577f Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Mon, 22 Jan 2024 18:22:56 +0200 Subject: [PATCH 03/24] Convert layer to recieve attn_scores and add the alibi bias to it. --- keras_nlp/layers/modeling/alibi_bias.py | 105 +++++++++--------------- 1 file changed, 41 insertions(+), 64 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 62638033f2..967c224929 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -20,41 +20,29 @@ @keras_nlp_export("keras_nlp.layers.AlibiBias") class AlibiBias(keras.layers.Layer): - """A layer that generates alibi bias + """A layer that add the alibi bias to attention scores - This layer generates a linear, non-learned bias. Defined and formalized in + This layer generates a linear, non-learned bias. Defined and formalized in [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409). - Takes as input an embedded token tensor. The input must have shape - `(batch_size, sequence_length, hidden_dim)`. This layer will return an alibi - bias of the shape `(1, num_heads, 1, sequence_length)`, which will be added to - the result of the query-key dot product in the multi-head attention layer of - the transformer. + Takes as input an attention score. The input must have shape + `(batch_size, num_heads, query_length, key_length)`. This layer will return + an the attention scores after adding the alibi bias which will have the same + shape as the input. Args: - num_heads: int. The number of heads in the multi-head attention layer of - the transformer. - alibi_bias_max: int. This value will be used to compute the slope of - each head. The heads slopes is a geometric sequence that starts at - `2**(-alibi_bias_max/num_heads)` and uses that same value as its + alibi_bias_max: int. This value will be used to compute the slope of + each head. The heads slopes is a geometric sequence that starts at + `2**(-alibi_bias_max/num_heads)` and uses that same value as its ratio. Defaults to 8. - full: bool. Whether to return the full alibi bias tensor. If set to - `True`, the alibi bias shape will be - `(1, num_heads, sequence_length, sequence_length)`. Defaults to - `False`, so the third dimension will be broadcasted, and this will - work because of the translation invariance property of the softmax, - let `L` be a tensor and `x` a constant, `softmax(L+x) = softmax(L)` - batched: bool. Whether to return the alibi bias tensor with first - dimention equal to `batch_size`. If set to `True` the alibi bias - shape wil be `(batch_size, num_heads, 1, sequence_length)`. Defaults - to `False`, so the first dimension will be broadcasted. Call arguments: - inputs: The tensor inputs to compute an embedding for, with shape - `(batch_size, sequence_length, hidden_dim)`. + attention_scores: The result of multipying the query and the key of the + multi head attention of the transformer. with shape + `(batch_size, num_heads, query_length, key_length)`. Examples: ```python - # create a simple layer that takes token embeddings as input and generates + # create a simple layer that takes token embeddings as input and generates # the alibi tensor seq_len = 100 vocab_size = 1000 @@ -72,40 +60,40 @@ class AlibiBias(keras.layers.Layer): def __init__( self, - num_heads, alibi_bias_max=8, - full=False, - batched=False, **kwargs, ): super().__init__(**kwargs) - self.num_heads = num_heads self.alibi_bias_max = alibi_bias_max - self.full = full - self.batched = batched def call(self, inputs): shape = ops.shape(inputs) - batch_size = shape[0] - seq_length = shape[1] + if ( len(shape) != 4): + raise ValueError("Expected inputs of shape (batch_size, num_heads, " + f"query_length, key_length) but recieved inputs of shape {shape}") + + num_heads = shape[1] + seq_length = shape[-1] + alibi_bias = self._get_alibi_bias(num_heads, seq_length) - slopes = ops.convert_to_tensor(self._get_slopes(), dtype=float) - slopes = ops.reshape(slopes, (self.num_heads, 1, 1)) + return ops.add(inputs, alibi_bias) + def _get_alibi_bias(self, num_heads, seq_length): + slopes = ops.convert_to_tensor(self._get_slopes(num_heads), dtype=float) + slopes = ops.expand_dims(slopes, 1) - sequence_range = ops.expand_dims(ops.arange( 1 - seq_length, 1, dtype=float), 0) - if self.full: - sequence_range = ops.subtract(sequence_range, ops.expand_dims(ops.arange( 1 - seq_length, 1, dtype=float), 1)) - sequence_range = ops.multiply(ops.abs(sequence_range), -1) - - alibi_bias = slopes * ops.expand_dims(ops.arange(seq_length, dtype=float), 0) + seq_range = ops.expand_dims( + ops.arange(1.0 - seq_length, 1.0, dtype=float), 0 + ) + alibi_bias = ops.multiply(slopes, seq_range) + alibi_bias = ops.expand_dims(alibi_bias, 1) alibi_bias = ops.expand_dims(alibi_bias, 0) - if self.batched: - return ops.repeat(alibi_bias, batch_size, axis=0) - return alibi_bias - - def _get_slopes(self): + return ops.convert_to_tensor(alibi_bias, dtype=self.compute_dtype) + + + + def _get_slopes(self, num_heads): # this function is adopted from Alibi original implementation # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 def get_slopes_power_of_2(n): @@ -115,36 +103,25 @@ def get_slopes_power_of_2(n): ratio = start return [start * ratio**i for i in range(n)] - if math.log2(self.num_heads).is_integer(): - return get_slopes_power_of_2(self.num_heads) + if math.log2(num_heads).is_integer(): + return get_slopes_power_of_2(num_heads) else: - closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads)) + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) return ( get_slopes_power_of_2(closest_power_of_2) + self._get_slopes(2 * closest_power_of_2)[0::2][ - : self.num_heads - closest_power_of_2 + : num_heads - closest_power_of_2 ] ) def compute_output_shape(self, input_shape): - batch_size = input_shape[0] - seq_length = input_shape[1] - output_shape = [1, self.num_heads, 1, seq_length] - if self.full: - output_shape[2] = seq_length - if self.batched: - output_shape[0] = batch_size - - return tuple(output_shape) - + return input_shape + def get_config(self): config = super().get_config() config.update( { - "num_heads": self.num_heads, "alibi_bias_max": self.alibi_bias_max, - "full": self.full, - "batched": self.batched, } ) - return config \ No newline at end of file + return config From 9c891ae245d73dfc6542ea84619bbe136185ae16 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Mon, 22 Jan 2024 21:23:19 +0200 Subject: [PATCH 04/24] Change layer logic --- keras_nlp/layers/modeling/alibi_bias.py | 60 ++++++++++++++----------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 967c224929..d62613b3cd 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -25,9 +25,8 @@ class AlibiBias(keras.layers.Layer): This layer generates a linear, non-learned bias. Defined and formalized in [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409). - Takes as input an attention score. The input must have shape - `(batch_size, num_heads, query_length, key_length)`. This layer will return - an the attention scores after adding the alibi bias which will have the same + Takes as input an attention score. This layer will return the attention + scores after adding the alibi bias to it. The output will have the same shape as the input. Args: @@ -37,8 +36,9 @@ class AlibiBias(keras.layers.Layer): ratio. Defaults to 8. Call arguments: attention_scores: The result of multipying the query and the key of the - multi head attention of the transformer. with shape - `(batch_size, num_heads, query_length, key_length)`. + multi head attention of the transformer. The shape must be greater + than or equal to 3 with the last 3 dimensions equal to + `(num_heads, query_length, key_length)`. Examples: ```python @@ -66,32 +66,42 @@ def __init__( super().__init__(**kwargs) self.alibi_bias_max = alibi_bias_max - def call(self, inputs): - shape = ops.shape(inputs) - if ( len(shape) != 4): - raise ValueError("Expected inputs of shape (batch_size, num_heads, " - f"query_length, key_length) but recieved inputs of shape {shape}") - - num_heads = shape[1] - seq_length = shape[-1] - alibi_bias = self._get_alibi_bias(num_heads, seq_length) + def call(self, attention_scores): + shape = ops.shape(attention_scores) + print(shape) + if len(shape) < 3: + raise ValueError( + "Expected `attention_scores` shape to be " + "`(..., num_heads, query_length, key_Length)`." + f" Recived shape={shape}" + ) - return ops.add(inputs, alibi_bias) + key_length = shape[-1] + num_heads = shape[-3] - def _get_alibi_bias(self, num_heads, seq_length): - slopes = ops.convert_to_tensor(self._get_slopes(num_heads), dtype=float) - slopes = ops.expand_dims(slopes, 1) + alibi_bias = self._get_alibi_bias(num_heads, key_length) + alibi_bias = ops.reshape( + alibi_bias, + tuple([1 for _ in range(len(shape[:-3]))]) + + (num_heads, 1, key_length), + ) + + return ops.add(attention_scores, alibi_bias) - seq_range = ops.expand_dims( - ops.arange(1.0 - seq_length, 1.0, dtype=float), 0 + def _get_alibi_bias(self, num_heads, key_length): + slopes = ops.convert_to_tensor( + self._get_slopes(num_heads), dtype=self.compute_dtype ) - alibi_bias = ops.multiply(slopes, seq_range) - alibi_bias = ops.expand_dims(alibi_bias, 1) - alibi_bias = ops.expand_dims(alibi_bias, 0) + slopes = ops.expand_dims(slopes, 1) - return ops.convert_to_tensor(alibi_bias, dtype=self.compute_dtype) + seq_range = ops.expand_dims(ops.arange(1.0 - key_length, 1.0), 0) + seq_range = ops.cast(seq_range, dtype=self.compute_dtype) + + alibi_bias = ops.multiply(slopes, seq_range) - + # Expand on query dimension + # return shape is `(num_heads, 1, key_length)` + return ops.expand_dims(alibi_bias, 1) def _get_slopes(self, num_heads): # this function is adopted from Alibi original implementation From 367647618218b33c83da29d8a911ea4602604583 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Mon, 22 Jan 2024 21:25:48 +0200 Subject: [PATCH 05/24] Add layer test --- keras_nlp/layers/modeling/alibi_bias_test.py | 208 +++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 keras_nlp/layers/modeling/alibi_bias_test.py diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py new file mode 100644 index 0000000000..900ddcda04 --- /dev/null +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -0,0 +1,208 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.backend import random +from keras_nlp.layers.modeling.alibi_bias import AlibiBias +from keras_nlp.tests.test_case import TestCase + + +class AlibiBiasTest(TestCase): + def test_layer_behaviors(self): + alibi_bias_max = 8 + batch_size = 4 + num_heads = 8 + query_length = 10 + key_length = 10 + self.run_layer_test( + cls=AlibiBias, + init_kwargs={ + "alibi_bias_max": alibi_bias_max, + }, + input_data=random.uniform( + shape=(batch_size, num_heads, query_length, key_length) + ), + expected_output_shape=( + batch_size, + num_heads, + query_length, + key_length, + ), + ) + + def test_float16_dtype(self): + # Create a 4-dimensional input (the first dimension is implicit). + alibi_bias_max = 8 + num_heads = 8 + query_length = 5 + key_length = 10 + test_layer = AlibiBias(alibi_bias_max=alibi_bias_max, dtype="float16") + input_tensor = keras.Input(shape=(num_heads, query_length, key_length)) + output_tensor = test_layer(input_tensor) + + # the output is expected to be the same as the input shape in all + # dimensions. here, the first dimension is implicit and is for batch + expected_output_shape = (None, num_heads, query_length, key_length) + self.assertEqual(expected_output_shape, output_tensor.shape) + # The default output dtype for this layer should be "float32". + self.assertEqual("float16", output_tensor.dtype) + + def test_dynamic_layer_output_shape(self): + query_length = 10 + key_length = 10 + num_heads = 4 + + test_layer = AlibiBias() + # Create a 4-dimensional input (the first dimension is implicit). + input_tensor = keras.Input( + shape=(None, num_heads, query_length, key_length) + ) + output_tensor = test_layer(input_tensor) + + # the output is expected to be the same as the input shape in all + # dimensions - but may be None if the input shape is None there. + expected_output_shape = ( + None, + None, + num_heads, + query_length, + key_length, + ) + self.assertEqual(expected_output_shape, output_tensor.shape) + + def test_value_error_when_inputs_shape_is_less_than_3(self): + with self.assertRaises(ValueError): + AlibiBias()(random.uniform(shape=(12, 12))) + + def test_num_heads_is_not_power_of_two(self): + inputs_shape = (12, 12, 12) + inputs = random.uniform(shape=inputs_shape) + layer = AlibiBias() + outputs = layer(inputs) + self.assertEqual(inputs_shape, outputs.shape) + + def test_correct_output(self): + batch_size = 1 + num_heads = 8 + query_length = 1 + key_length = 3 + input_shape = (batch_size, num_heads, query_length, key_length) + input_tensor = ops.zeros(input_shape) + layer = AlibiBias() + output_tensor = layer(input_tensor) + print(output_tensor) + self.assertAllClose( + output_tensor, + ops.convert_to_tensor( + [ + [ + [[-1.0, -0.5, 0.0]], + [[-0.5, -0.25, 0.0]], + [[-0.25, -0.125, 0.0]], + [[-0.125, -0.0625, 0.0]], + [[-0.0625, -0.03125, 0.0]], + [[-0.03125, -0.015625, 0.0]], + [[-0.015625, -0.0078125, 0.0]], + [[-0.0078125, -0.00390625, 0.0]], + ] + ] + ), + ) + + def test_correct_output_num_heads_not_power_of_two(self): + batch_size = 1 + num_heads = 14 + query_length = 1 + key_length = 3 + input_shape = (batch_size, num_heads, query_length, key_length) + input_tensor = ops.zeros(input_shape) + layer = AlibiBias() + output_tensor = layer(input_tensor) + print(output_tensor) + self.assertAllClose( + output_tensor, + ops.convert_to_tensor( + [ + [ + [[-1.0, -0.5, 0.0]], + [[-0.5, -0.25, 0.0]], + [[-0.25, -0.125, 0.0]], + [[-0.125, -0.0625, 0.0]], + [[-0.0625, -0.03125, 0.0]], + [[-0.03125, -0.015625, 0.0]], + [[-0.015625, -0.0078125, 0.0]], + [[-0.0078125, -0.00390625, 0.0]], + [[-1.4142135, -0.70710677, 0.0]], + [[-0.70710677, -0.35355338, 0.0]], + [[-0.35355338, -0.17677669, 0.0]], + [[-0.17677669, -0.08838835, 0.0]], + [[-0.08838835, -0.04419417, 0.0]], + [[-0.04419417, -0.02209709, 0.0]], + ] + ] + ), + ) + + def test_correct_output_alibi_bias_max(self): + alibi_bias_max = 12 + batch_size = 1 + num_heads = 2 + query_length = 1 + key_length = 3 + input_shape = (batch_size, num_heads, query_length, key_length) + input_tensor = ops.zeros(input_shape) + layer = AlibiBias(alibi_bias_max=alibi_bias_max) + output_tensor = layer(input_tensor) + print(output_tensor) + self.assertAllClose( + output_tensor, + ops.convert_to_tensor( + [ + [ + [[-0.03125, -0.015625, 0.0]], + [[-0.00048828, -0.00024414, 0.0]], + ] + ] + ), + ) + + def test_correct_output_alibi_bias_max_num_heads_not_power_of_two( + self, + ): + alibi_bias_max = 6 + batch_size = 1 + num_heads = 3 + query_length = 1 + key_length = 3 + input_shape = (batch_size, num_heads, query_length, key_length) + input_tensor = ops.zeros(input_shape) + layer = AlibiBias(alibi_bias_max=alibi_bias_max) + output_tensor = layer(input_tensor) + print(output_tensor) + self.assertAllClose( + output_tensor, + ops.convert_to_tensor( + [ + [ + [[-0.25, -0.125, 0.0]], + [[-0.03125, -0.015625, 0.0]], + [[-0.70710677, -0.35355338, 0.0]], + + ] + ] + ), + ) From 358e6f0472ee25b3dfea08bca072d7e41bbc7bef Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Mon, 22 Jan 2024 21:29:30 +0200 Subject: [PATCH 06/24] Format the code --- keras_nlp/layers/modeling/alibi_bias_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py index 900ddcda04..b66ce18516 100644 --- a/keras_nlp/layers/modeling/alibi_bias_test.py +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np - from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.backend import random @@ -201,7 +199,6 @@ def test_correct_output_alibi_bias_max_num_heads_not_power_of_two( [[-0.25, -0.125, 0.0]], [[-0.03125, -0.015625, 0.0]], [[-0.70710677, -0.35355338, 0.0]], - ] ] ), From 0c949c4ea6076d8f7da4f4355fc04c8bf937efdf Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Mon, 22 Jan 2024 21:43:32 +0200 Subject: [PATCH 07/24] Fix seq_range creation to be int range --- keras_nlp/layers/modeling/alibi_bias.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index d62613b3cd..666b467ccc 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -94,7 +94,7 @@ def _get_alibi_bias(self, num_heads, key_length): ) slopes = ops.expand_dims(slopes, 1) - seq_range = ops.expand_dims(ops.arange(1.0 - key_length, 1.0), 0) + seq_range = ops.expand_dims(ops.arange(1 - key_length, 1), 0) seq_range = ops.cast(seq_range, dtype=self.compute_dtype) alibi_bias = ops.multiply(slopes, seq_range) From 5f1ebc111822be2aac90122edc89df04c11b7f89 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Mon, 22 Jan 2024 21:44:29 +0200 Subject: [PATCH 08/24] Change bloom model to use alibi bias --- keras_nlp/models/bloom/bloom_attention.py | 51 ++++------------------- 1 file changed, 9 insertions(+), 42 deletions(-) diff --git a/keras_nlp/models/bloom/bloom_attention.py b/keras_nlp/models/bloom/bloom_attention.py index 7af2e7a34d..e26810a72e 100644 --- a/keras_nlp/models/bloom/bloom_attention.py +++ b/keras_nlp/models/bloom/bloom_attention.py @@ -15,6 +15,7 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops +from keras_nlp.layers.modeling.alibi_bias import AlibiBias from keras_nlp.utils.keras_utils import clone_initializer @@ -74,6 +75,8 @@ def build(self, inputs_shape): ) self._value_dense.build(inputs_shape) + self._alibi_layer = AlibiBias() + self._output_dense = keras.layers.Dense( hidden_dim, kernel_initializer=clone_initializer(self.kernel_initializer), @@ -92,37 +95,6 @@ def build(self, inputs_shape): self.built = True - @staticmethod - def _build_alibi_tensor(num_heads, seq_length, alibi_bias_max=8): - # this function is adopted from fairseq - # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 - def get_slopes(n): - def get_slopes_power_of_2(n): - start = 2 ** ( - -(2 ** -(math.log2(n) - math.log2(alibi_bias_max))) - ) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + get_slopes(2 * closest_power_of_2)[0::2][ - : n - closest_power_of_2 - ] - ) - - slopes = ops.convert_to_tensor(get_slopes(num_heads), dtype=float) - slopes = ops.expand_dims(slopes, 1) - - alibi = slopes * ops.expand_dims(ops.arange(seq_length, dtype=float), 0) - alibi = ops.expand_dims(alibi, 1) - alibi = ops.expand_dims(alibi, 0) - - return alibi def call( self, @@ -163,20 +135,15 @@ def call( # key (batch_size, num_heads, head_dim, kv_length) key = ops.transpose(key, [0, 2, 3, 1]) - alibi = self._build_alibi_tensor( - num_heads=self.num_heads, seq_length=seq_length - ) - - scores = ( - ops.matmul(query, key) * self.inv_norm_factor + alibi + attention_scores = ( + ops.matmul(query, key) * self.inv_norm_factor ) # [batch_size, num_heads, query_length, kv_length] - - scores = self._softmax(scores, ops.expand_dims(attention_mask, 1)) - - scores = self._dropout_layer(scores) + attention_scores = self._alibi_layer(attention_scores) + attention_scores = self._softmax(attention_scores, ops.expand_dims(attention_mask, 1)) + attention_scores = self._dropout_layer(attention_scores) attention_output = ops.matmul( - scores, value + attention_scores, value ) # [batch_size, num_heads, query_length, head_dim] attention_output = ops.transpose( From aa3b15ff602457b6aae6be62b7eb39d2c739c518 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Mon, 22 Jan 2024 21:45:11 +0200 Subject: [PATCH 09/24] Format the code --- keras_nlp/models/bloom/bloom_attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras_nlp/models/bloom/bloom_attention.py b/keras_nlp/models/bloom/bloom_attention.py index e26810a72e..366169d1fb 100644 --- a/keras_nlp/models/bloom/bloom_attention.py +++ b/keras_nlp/models/bloom/bloom_attention.py @@ -95,7 +95,6 @@ def build(self, inputs_shape): self.built = True - def call( self, hidden_states, @@ -136,10 +135,12 @@ def call( key = ops.transpose(key, [0, 2, 3, 1]) attention_scores = ( - ops.matmul(query, key) * self.inv_norm_factor + ops.matmul(query, key) * self.inv_norm_factor ) # [batch_size, num_heads, query_length, kv_length] attention_scores = self._alibi_layer(attention_scores) - attention_scores = self._softmax(attention_scores, ops.expand_dims(attention_mask, 1)) + attention_scores = self._softmax( + attention_scores, ops.expand_dims(attention_mask, 1) + ) attention_scores = self._dropout_layer(attention_scores) attention_output = ops.matmul( From f462ee296cb4a168bcc9697c8940bdac07d57d5e Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Tue, 23 Jan 2024 21:44:55 +0200 Subject: [PATCH 10/24] Remove print function --- keras_nlp/layers/modeling/alibi_bias.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 666b467ccc..a11422da3f 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -68,7 +68,6 @@ def __init__( def call(self, attention_scores): shape = ops.shape(attention_scores) - print(shape) if len(shape) < 3: raise ValueError( "Expected `attention_scores` shape to be " From 726946d428d8c6e73e11150e12e7d890ad72d278 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Wed, 24 Jan 2024 01:08:05 +0200 Subject: [PATCH 11/24] Change logic to only compute alibi bias once --- keras_nlp/layers/modeling/alibi_bias.py | 42 +++++++--- keras_nlp/layers/modeling/alibi_bias_test.py | 85 +++++++++++++++++--- 2 files changed, 108 insertions(+), 19 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index a11422da3f..8d5c1d67a7 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -60,33 +60,54 @@ class AlibiBias(keras.layers.Layer): def __init__( self, + max_sequence_length, alibi_bias_max=8, **kwargs, ): super().__init__(**kwargs) + self.max_sequence_length = max_sequence_length self.alibi_bias_max = alibi_bias_max + def build(self, inputs_shape): + if len(inputs_shape) < 3: + raise ValueError( + "Expected `attention_scores` shape to be " + "`(..., num_heads, query_length, key_length)`." + f" Received shape={inputs_shape}" + ) + num_heads = inputs_shape[-3] + alibi_bias_shape = tuple([1 for _ in range(len(inputs_shape[:-3]))]) + ( + num_heads, + 1, + self.max_sequence_length, + ) + self.alibi_bias = self.add_weight( + shape=alibi_bias_shape, trainable=False + ) + alibi_bias = self._get_alibi_bias(num_heads, self.max_sequence_length) + alibi_bias = ops.reshape( + alibi_bias, + alibi_bias_shape, + ) + + self.alibi_bias.assign(alibi_bias) + def call(self, attention_scores): shape = ops.shape(attention_scores) if len(shape) < 3: raise ValueError( "Expected `attention_scores` shape to be " - "`(..., num_heads, query_length, key_Length)`." - f" Recived shape={shape}" + "`(..., num_heads, query_length, key_length)`." + f" Received shape={shape}" ) key_length = shape[-1] - num_heads = shape[-3] - alibi_bias = self._get_alibi_bias(num_heads, key_length) - alibi_bias = ops.reshape( - alibi_bias, - tuple([1 for _ in range(len(shape[:-3]))]) - + (num_heads, 1, key_length), + return ops.add( + attention_scores, + self.alibi_bias[..., self.max_sequence_length - key_length :], ) - return ops.add(attention_scores, alibi_bias) - def _get_alibi_bias(self, num_heads, key_length): slopes = ops.convert_to_tensor( self._get_slopes(num_heads), dtype=self.compute_dtype @@ -130,6 +151,7 @@ def get_config(self): config = super().get_config() config.update( { + "max_sequence_length": self.max_sequence_length, "alibi_bias_max": self.alibi_bias_max, } ) diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py index b66ce18516..9cf6d5099c 100644 --- a/keras_nlp/layers/modeling/alibi_bias_test.py +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -24,11 +24,13 @@ def test_layer_behaviors(self): alibi_bias_max = 8 batch_size = 4 num_heads = 8 + max_sequence_length = 10 query_length = 10 key_length = 10 self.run_layer_test( cls=AlibiBias, init_kwargs={ + "max_sequence_length": max_sequence_length, "alibi_bias_max": alibi_bias_max, }, input_data=random.uniform( @@ -40,6 +42,8 @@ def test_layer_behaviors(self): query_length, key_length, ), + expected_num_non_trainable_weights=1, + expected_num_non_trainable_variables=1, ) def test_float16_dtype(self): @@ -47,8 +51,13 @@ def test_float16_dtype(self): alibi_bias_max = 8 num_heads = 8 query_length = 5 + max_sequence_length = 10 key_length = 10 - test_layer = AlibiBias(alibi_bias_max=alibi_bias_max, dtype="float16") + test_layer = AlibiBias( + max_sequence_length=max_sequence_length, + alibi_bias_max=alibi_bias_max, + dtype="float16", + ) input_tensor = keras.Input(shape=(num_heads, query_length, key_length)) output_tensor = test_layer(input_tensor) @@ -60,11 +69,12 @@ def test_float16_dtype(self): self.assertEqual("float16", output_tensor.dtype) def test_dynamic_layer_output_shape(self): + max_sequence_length = 10 query_length = 10 key_length = 10 num_heads = 4 - test_layer = AlibiBias() + test_layer = AlibiBias(max_sequence_length=max_sequence_length) # Create a 4-dimensional input (the first dimension is implicit). input_tensor = keras.Input( shape=(None, num_heads, query_length, key_length) @@ -83,13 +93,25 @@ def test_dynamic_layer_output_shape(self): self.assertEqual(expected_output_shape, output_tensor.shape) def test_value_error_when_inputs_shape_is_less_than_3(self): + max_sequence_length = 12 with self.assertRaises(ValueError): - AlibiBias()(random.uniform(shape=(12, 12))) + AlibiBias(max_sequence_length=max_sequence_length)( + random.uniform(shape=(12, 12)) + ) + + def test_key_length_is_less_than_max_sequence_length(self): + max_sequence_length = 20 + inputs_shape = (5, 4, 4, 12, 12) + inputs = random.uniform(shape=inputs_shape) + layer = AlibiBias(max_sequence_length=max_sequence_length) + outputs = layer(inputs) + self.assertEqual(inputs_shape, outputs.shape) def test_num_heads_is_not_power_of_two(self): - inputs_shape = (12, 12, 12) + max_sequence_length = 12 + inputs_shape = (1, 12, 12) inputs = random.uniform(shape=inputs_shape) - layer = AlibiBias() + layer = AlibiBias(max_sequence_length=max_sequence_length) outputs = layer(inputs) self.assertEqual(inputs_shape, outputs.shape) @@ -98,9 +120,12 @@ def test_correct_output(self): num_heads = 8 query_length = 1 key_length = 3 + max_sequence_length = 3 input_shape = (batch_size, num_heads, query_length, key_length) input_tensor = ops.zeros(input_shape) - layer = AlibiBias() + layer = AlibiBias( + max_sequence_length=max_sequence_length, + ) output_tensor = layer(input_tensor) print(output_tensor) self.assertAllClose( @@ -126,9 +151,12 @@ def test_correct_output_num_heads_not_power_of_two(self): num_heads = 14 query_length = 1 key_length = 3 + max_sequence_length = 3 input_shape = (batch_size, num_heads, query_length, key_length) input_tensor = ops.zeros(input_shape) - layer = AlibiBias() + layer = AlibiBias( + max_sequence_length=max_sequence_length, + ) output_tensor = layer(input_tensor) print(output_tensor) self.assertAllClose( @@ -161,9 +189,13 @@ def test_correct_output_alibi_bias_max(self): num_heads = 2 query_length = 1 key_length = 3 + max_sequence_length = 3 input_shape = (batch_size, num_heads, query_length, key_length) input_tensor = ops.zeros(input_shape) - layer = AlibiBias(alibi_bias_max=alibi_bias_max) + layer = AlibiBias( + max_sequence_length=max_sequence_length, + alibi_bias_max=alibi_bias_max, + ) output_tensor = layer(input_tensor) print(output_tensor) self.assertAllClose( @@ -186,9 +218,13 @@ def test_correct_output_alibi_bias_max_num_heads_not_power_of_two( num_heads = 3 query_length = 1 key_length = 3 + max_sequence_length = 3 input_shape = (batch_size, num_heads, query_length, key_length) input_tensor = ops.zeros(input_shape) - layer = AlibiBias(alibi_bias_max=alibi_bias_max) + layer = AlibiBias( + max_sequence_length=max_sequence_length, + alibi_bias_max=alibi_bias_max, + ) output_tensor = layer(input_tensor) print(output_tensor) self.assertAllClose( @@ -203,3 +239,34 @@ def test_correct_output_alibi_bias_max_num_heads_not_power_of_two( ] ), ) + + def test_correct_output_key_length_smaller_than_max_sequence_length(self): + batch_size = 1 + num_heads = 8 + query_length = 1 + key_length = 3 + max_sequence_length = 10 + input_shape = (batch_size, num_heads, query_length, key_length) + input_tensor = ops.zeros(input_shape) + layer = AlibiBias( + max_sequence_length=max_sequence_length, + ) + output_tensor = layer(input_tensor) + print(output_tensor) + self.assertAllClose( + output_tensor, + ops.convert_to_tensor( + [ + [ + [[-1.0, -0.5, 0.0]], + [[-0.5, -0.25, 0.0]], + [[-0.25, -0.125, 0.0]], + [[-0.125, -0.0625, 0.0]], + [[-0.0625, -0.03125, 0.0]], + [[-0.03125, -0.015625, 0.0]], + [[-0.015625, -0.0078125, 0.0]], + [[-0.0078125, -0.00390625, 0.0]], + ] + ] + ), + ) From 56a8c5904cd9e5be2bf185b5b77de5adbcf3fb28 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Wed, 24 Jan 2024 01:08:52 +0200 Subject: [PATCH 12/24] Change bloom model API calls to much new alibi layer API --- keras_nlp/models/bloom/bloom_attention.py | 7 ++++++- keras_nlp/models/bloom/bloom_backbone.py | 1 + keras_nlp/models/bloom/bloom_decoder.py | 6 ++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/keras_nlp/models/bloom/bloom_attention.py b/keras_nlp/models/bloom/bloom_attention.py index 366169d1fb..030ff6175a 100644 --- a/keras_nlp/models/bloom/bloom_attention.py +++ b/keras_nlp/models/bloom/bloom_attention.py @@ -23,6 +23,7 @@ class BloomAttention(keras.layers.Layer): def __init__( self, num_heads, + max_sequence_length, dropout=0.0, kernel_initializer="glorot_uniform", bias_initializer="zeros", @@ -30,6 +31,7 @@ def __init__( ): super().__init__(**kwargs) self.num_heads = num_heads + self.max_sequence_length = max_sequence_length self.dropout = dropout self.kernel_initializer = keras.initializers.get(kernel_initializer) self.bias_initializer = keras.initializers.get(bias_initializer) @@ -75,7 +77,10 @@ def build(self, inputs_shape): ) self._value_dense.build(inputs_shape) - self._alibi_layer = AlibiBias() + self._alibi_layer = AlibiBias( + max_sequence_length=self.max_sequence_length + ) + self._alibi_layer.build((batch_size, self.num_heads, 1, seq_length)) self._output_dense = keras.layers.Dense( hidden_dim, diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index e3d66998bc..cf0c76de9b 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -109,6 +109,7 @@ def __init__( x = BloomDecoder( num_heads=num_heads, intermediate_dim=intermediate_dim, + max_sequence_length=max_sequence_length, dropout=dropout, layer_norm_epsilon=layer_norm_epsilon, name=f"transformer_layer_{i}", diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py index b3f8b80da7..b95eefbb2b 100644 --- a/keras_nlp/models/bloom/bloom_decoder.py +++ b/keras_nlp/models/bloom/bloom_decoder.py @@ -30,6 +30,7 @@ def __init__( self, num_heads, intermediate_dim, + max_sequence_length, dropout=0.0, layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", @@ -40,6 +41,7 @@ def __init__( self.num_heads = num_heads self.intermediate_dim = intermediate_dim + self.max_sequence_length = max_sequence_length self.dropout = dropout self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) @@ -64,6 +66,7 @@ def build(self, decoder_sequence_shape): self._self_attention_layer = BloomAttention( num_heads=self.num_heads, + max_sequence_length=self.max_sequence_length, dropout=self.dropout, kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), @@ -184,6 +187,9 @@ def _compute_attention_mask( else causal_mask ) return decoder_mask + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape def get_config(self): config = super().get_config() From 7408fd80c0b86b0885659dd64e53d36801c28d61 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Wed, 24 Jan 2024 01:09:33 +0200 Subject: [PATCH 13/24] Format the code --- keras_nlp/models/bloom/bloom_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py index b95eefbb2b..06d0127361 100644 --- a/keras_nlp/models/bloom/bloom_decoder.py +++ b/keras_nlp/models/bloom/bloom_decoder.py @@ -187,7 +187,7 @@ def _compute_attention_mask( else causal_mask ) return decoder_mask - + def compute_output_shape(self, decoder_sequence_shape): return decoder_sequence_shape From 91c0e04aba806f9dafe074fcaa986aab6bc5b97e Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Wed, 24 Jan 2024 01:30:57 +0200 Subject: [PATCH 14/24] Add dtype kwarg for the layer weight --- keras_nlp/layers/modeling/alibi_bias.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 8d5c1d67a7..8819d1f20b 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -82,7 +82,7 @@ def build(self, inputs_shape): self.max_sequence_length, ) self.alibi_bias = self.add_weight( - shape=alibi_bias_shape, trainable=False + shape=alibi_bias_shape, dtype=self.compute_dtype, trainable=False ) alibi_bias = self._get_alibi_bias(num_heads, self.max_sequence_length) alibi_bias = ops.reshape( From 410385f781298f9fe493ee6108fbef9ce2e88194 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Wed, 24 Jan 2024 01:33:38 +0200 Subject: [PATCH 15/24] Revert "Add dtype kwarg for the layer weight" This reverts commit 91c0e04aba806f9dafe074fcaa986aab6bc5b97e. --- keras_nlp/layers/modeling/alibi_bias.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 8819d1f20b..8d5c1d67a7 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -82,7 +82,7 @@ def build(self, inputs_shape): self.max_sequence_length, ) self.alibi_bias = self.add_weight( - shape=alibi_bias_shape, dtype=self.compute_dtype, trainable=False + shape=alibi_bias_shape, trainable=False ) alibi_bias = self._get_alibi_bias(num_heads, self.max_sequence_length) alibi_bias = ops.reshape( From 8fc2616fbdd81db4b506435fde6cc610e1494a2a Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Wed, 24 Jan 2024 01:37:24 +0200 Subject: [PATCH 16/24] Cast after adding alibi bias --- keras_nlp/layers/modeling/alibi_bias.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 8d5c1d67a7..0ab9cb5ede 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -102,20 +102,19 @@ def call(self, attention_scores): ) key_length = shape[-1] - - return ops.add( + attention_scores = ops.add( attention_scores, self.alibi_bias[..., self.max_sequence_length - key_length :], ) + return ops.cast(attention_scores, self.compute_dtype) + def _get_alibi_bias(self, num_heads, key_length): - slopes = ops.convert_to_tensor( - self._get_slopes(num_heads), dtype=self.compute_dtype - ) + slopes = ops.convert_to_tensor(self._get_slopes(num_heads), dtype=float) slopes = ops.expand_dims(slopes, 1) seq_range = ops.expand_dims(ops.arange(1 - key_length, 1), 0) - seq_range = ops.cast(seq_range, dtype=self.compute_dtype) + seq_range = ops.cast(seq_range, dtype=float) alibi_bias = ops.multiply(slopes, seq_range) From 8e750b7b0f7448a884ee1791ad4032694f96252f Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Thu, 25 Jan 2024 13:45:50 +0200 Subject: [PATCH 17/24] Return to compute ALibi bias at each call --- keras_nlp/layers/modeling/alibi_bias.py | 49 ++++------- keras_nlp/layers/modeling/alibi_bias_test.py | 85 +++----------------- keras_nlp/models/bloom/bloom_attention.py | 7 +- keras_nlp/models/bloom/bloom_backbone.py | 1 - keras_nlp/models/bloom/bloom_decoder.py | 6 -- 5 files changed, 24 insertions(+), 124 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 0ab9cb5ede..a11422da3f 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -60,61 +60,41 @@ class AlibiBias(keras.layers.Layer): def __init__( self, - max_sequence_length, alibi_bias_max=8, **kwargs, ): super().__init__(**kwargs) - self.max_sequence_length = max_sequence_length self.alibi_bias_max = alibi_bias_max - def build(self, inputs_shape): - if len(inputs_shape) < 3: - raise ValueError( - "Expected `attention_scores` shape to be " - "`(..., num_heads, query_length, key_length)`." - f" Received shape={inputs_shape}" - ) - num_heads = inputs_shape[-3] - alibi_bias_shape = tuple([1 for _ in range(len(inputs_shape[:-3]))]) + ( - num_heads, - 1, - self.max_sequence_length, - ) - self.alibi_bias = self.add_weight( - shape=alibi_bias_shape, trainable=False - ) - alibi_bias = self._get_alibi_bias(num_heads, self.max_sequence_length) - alibi_bias = ops.reshape( - alibi_bias, - alibi_bias_shape, - ) - - self.alibi_bias.assign(alibi_bias) - def call(self, attention_scores): shape = ops.shape(attention_scores) if len(shape) < 3: raise ValueError( "Expected `attention_scores` shape to be " - "`(..., num_heads, query_length, key_length)`." - f" Received shape={shape}" + "`(..., num_heads, query_length, key_Length)`." + f" Recived shape={shape}" ) key_length = shape[-1] - attention_scores = ops.add( - attention_scores, - self.alibi_bias[..., self.max_sequence_length - key_length :], + num_heads = shape[-3] + + alibi_bias = self._get_alibi_bias(num_heads, key_length) + alibi_bias = ops.reshape( + alibi_bias, + tuple([1 for _ in range(len(shape[:-3]))]) + + (num_heads, 1, key_length), ) - return ops.cast(attention_scores, self.compute_dtype) + return ops.add(attention_scores, alibi_bias) def _get_alibi_bias(self, num_heads, key_length): - slopes = ops.convert_to_tensor(self._get_slopes(num_heads), dtype=float) + slopes = ops.convert_to_tensor( + self._get_slopes(num_heads), dtype=self.compute_dtype + ) slopes = ops.expand_dims(slopes, 1) seq_range = ops.expand_dims(ops.arange(1 - key_length, 1), 0) - seq_range = ops.cast(seq_range, dtype=float) + seq_range = ops.cast(seq_range, dtype=self.compute_dtype) alibi_bias = ops.multiply(slopes, seq_range) @@ -150,7 +130,6 @@ def get_config(self): config = super().get_config() config.update( { - "max_sequence_length": self.max_sequence_length, "alibi_bias_max": self.alibi_bias_max, } ) diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py index 9cf6d5099c..b66ce18516 100644 --- a/keras_nlp/layers/modeling/alibi_bias_test.py +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -24,13 +24,11 @@ def test_layer_behaviors(self): alibi_bias_max = 8 batch_size = 4 num_heads = 8 - max_sequence_length = 10 query_length = 10 key_length = 10 self.run_layer_test( cls=AlibiBias, init_kwargs={ - "max_sequence_length": max_sequence_length, "alibi_bias_max": alibi_bias_max, }, input_data=random.uniform( @@ -42,8 +40,6 @@ def test_layer_behaviors(self): query_length, key_length, ), - expected_num_non_trainable_weights=1, - expected_num_non_trainable_variables=1, ) def test_float16_dtype(self): @@ -51,13 +47,8 @@ def test_float16_dtype(self): alibi_bias_max = 8 num_heads = 8 query_length = 5 - max_sequence_length = 10 key_length = 10 - test_layer = AlibiBias( - max_sequence_length=max_sequence_length, - alibi_bias_max=alibi_bias_max, - dtype="float16", - ) + test_layer = AlibiBias(alibi_bias_max=alibi_bias_max, dtype="float16") input_tensor = keras.Input(shape=(num_heads, query_length, key_length)) output_tensor = test_layer(input_tensor) @@ -69,12 +60,11 @@ def test_float16_dtype(self): self.assertEqual("float16", output_tensor.dtype) def test_dynamic_layer_output_shape(self): - max_sequence_length = 10 query_length = 10 key_length = 10 num_heads = 4 - test_layer = AlibiBias(max_sequence_length=max_sequence_length) + test_layer = AlibiBias() # Create a 4-dimensional input (the first dimension is implicit). input_tensor = keras.Input( shape=(None, num_heads, query_length, key_length) @@ -93,25 +83,13 @@ def test_dynamic_layer_output_shape(self): self.assertEqual(expected_output_shape, output_tensor.shape) def test_value_error_when_inputs_shape_is_less_than_3(self): - max_sequence_length = 12 with self.assertRaises(ValueError): - AlibiBias(max_sequence_length=max_sequence_length)( - random.uniform(shape=(12, 12)) - ) - - def test_key_length_is_less_than_max_sequence_length(self): - max_sequence_length = 20 - inputs_shape = (5, 4, 4, 12, 12) - inputs = random.uniform(shape=inputs_shape) - layer = AlibiBias(max_sequence_length=max_sequence_length) - outputs = layer(inputs) - self.assertEqual(inputs_shape, outputs.shape) + AlibiBias()(random.uniform(shape=(12, 12))) def test_num_heads_is_not_power_of_two(self): - max_sequence_length = 12 - inputs_shape = (1, 12, 12) + inputs_shape = (12, 12, 12) inputs = random.uniform(shape=inputs_shape) - layer = AlibiBias(max_sequence_length=max_sequence_length) + layer = AlibiBias() outputs = layer(inputs) self.assertEqual(inputs_shape, outputs.shape) @@ -120,12 +98,9 @@ def test_correct_output(self): num_heads = 8 query_length = 1 key_length = 3 - max_sequence_length = 3 input_shape = (batch_size, num_heads, query_length, key_length) input_tensor = ops.zeros(input_shape) - layer = AlibiBias( - max_sequence_length=max_sequence_length, - ) + layer = AlibiBias() output_tensor = layer(input_tensor) print(output_tensor) self.assertAllClose( @@ -151,12 +126,9 @@ def test_correct_output_num_heads_not_power_of_two(self): num_heads = 14 query_length = 1 key_length = 3 - max_sequence_length = 3 input_shape = (batch_size, num_heads, query_length, key_length) input_tensor = ops.zeros(input_shape) - layer = AlibiBias( - max_sequence_length=max_sequence_length, - ) + layer = AlibiBias() output_tensor = layer(input_tensor) print(output_tensor) self.assertAllClose( @@ -189,13 +161,9 @@ def test_correct_output_alibi_bias_max(self): num_heads = 2 query_length = 1 key_length = 3 - max_sequence_length = 3 input_shape = (batch_size, num_heads, query_length, key_length) input_tensor = ops.zeros(input_shape) - layer = AlibiBias( - max_sequence_length=max_sequence_length, - alibi_bias_max=alibi_bias_max, - ) + layer = AlibiBias(alibi_bias_max=alibi_bias_max) output_tensor = layer(input_tensor) print(output_tensor) self.assertAllClose( @@ -218,13 +186,9 @@ def test_correct_output_alibi_bias_max_num_heads_not_power_of_two( num_heads = 3 query_length = 1 key_length = 3 - max_sequence_length = 3 input_shape = (batch_size, num_heads, query_length, key_length) input_tensor = ops.zeros(input_shape) - layer = AlibiBias( - max_sequence_length=max_sequence_length, - alibi_bias_max=alibi_bias_max, - ) + layer = AlibiBias(alibi_bias_max=alibi_bias_max) output_tensor = layer(input_tensor) print(output_tensor) self.assertAllClose( @@ -239,34 +203,3 @@ def test_correct_output_alibi_bias_max_num_heads_not_power_of_two( ] ), ) - - def test_correct_output_key_length_smaller_than_max_sequence_length(self): - batch_size = 1 - num_heads = 8 - query_length = 1 - key_length = 3 - max_sequence_length = 10 - input_shape = (batch_size, num_heads, query_length, key_length) - input_tensor = ops.zeros(input_shape) - layer = AlibiBias( - max_sequence_length=max_sequence_length, - ) - output_tensor = layer(input_tensor) - print(output_tensor) - self.assertAllClose( - output_tensor, - ops.convert_to_tensor( - [ - [ - [[-1.0, -0.5, 0.0]], - [[-0.5, -0.25, 0.0]], - [[-0.25, -0.125, 0.0]], - [[-0.125, -0.0625, 0.0]], - [[-0.0625, -0.03125, 0.0]], - [[-0.03125, -0.015625, 0.0]], - [[-0.015625, -0.0078125, 0.0]], - [[-0.0078125, -0.00390625, 0.0]], - ] - ] - ), - ) diff --git a/keras_nlp/models/bloom/bloom_attention.py b/keras_nlp/models/bloom/bloom_attention.py index 030ff6175a..366169d1fb 100644 --- a/keras_nlp/models/bloom/bloom_attention.py +++ b/keras_nlp/models/bloom/bloom_attention.py @@ -23,7 +23,6 @@ class BloomAttention(keras.layers.Layer): def __init__( self, num_heads, - max_sequence_length, dropout=0.0, kernel_initializer="glorot_uniform", bias_initializer="zeros", @@ -31,7 +30,6 @@ def __init__( ): super().__init__(**kwargs) self.num_heads = num_heads - self.max_sequence_length = max_sequence_length self.dropout = dropout self.kernel_initializer = keras.initializers.get(kernel_initializer) self.bias_initializer = keras.initializers.get(bias_initializer) @@ -77,10 +75,7 @@ def build(self, inputs_shape): ) self._value_dense.build(inputs_shape) - self._alibi_layer = AlibiBias( - max_sequence_length=self.max_sequence_length - ) - self._alibi_layer.build((batch_size, self.num_heads, 1, seq_length)) + self._alibi_layer = AlibiBias() self._output_dense = keras.layers.Dense( hidden_dim, diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index cf0c76de9b..e3d66998bc 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -109,7 +109,6 @@ def __init__( x = BloomDecoder( num_heads=num_heads, intermediate_dim=intermediate_dim, - max_sequence_length=max_sequence_length, dropout=dropout, layer_norm_epsilon=layer_norm_epsilon, name=f"transformer_layer_{i}", diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py index 06d0127361..b3f8b80da7 100644 --- a/keras_nlp/models/bloom/bloom_decoder.py +++ b/keras_nlp/models/bloom/bloom_decoder.py @@ -30,7 +30,6 @@ def __init__( self, num_heads, intermediate_dim, - max_sequence_length, dropout=0.0, layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", @@ -41,7 +40,6 @@ def __init__( self.num_heads = num_heads self.intermediate_dim = intermediate_dim - self.max_sequence_length = max_sequence_length self.dropout = dropout self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) @@ -66,7 +64,6 @@ def build(self, decoder_sequence_shape): self._self_attention_layer = BloomAttention( num_heads=self.num_heads, - max_sequence_length=self.max_sequence_length, dropout=self.dropout, kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), @@ -188,9 +185,6 @@ def _compute_attention_mask( ) return decoder_mask - def compute_output_shape(self, decoder_sequence_shape): - return decoder_sequence_shape - def get_config(self): config = super().get_config() config.update( From a44c6e844f80c5d4c686633f3143bcb18272eda2 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Thu, 25 Jan 2024 13:52:18 +0200 Subject: [PATCH 18/24] Add compute output shape method for bloom decoder --- keras_nlp/models/bloom/bloom_decoder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py index b3f8b80da7..b1f89c7e0e 100644 --- a/keras_nlp/models/bloom/bloom_decoder.py +++ b/keras_nlp/models/bloom/bloom_decoder.py @@ -202,3 +202,6 @@ def get_config(self): } ) return config + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape From 7f8cc0ff09860541046843bcb4c5222283dfac78 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Thu, 25 Jan 2024 14:55:10 +0200 Subject: [PATCH 19/24] Force shape to be (batch_size, num_heads, query_length, key_length) --- keras_nlp/layers/modeling/alibi_bias.py | 57 ++++++++++---------- keras_nlp/layers/modeling/alibi_bias_test.py | 4 +- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index a11422da3f..c66b9f4320 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -20,38 +20,44 @@ @keras_nlp_export("keras_nlp.layers.AlibiBias") class AlibiBias(keras.layers.Layer): - """A layer that add the alibi bias to attention scores + """A layer that adds the alibi bias to attention scores. - This layer generates a linear, non-learned bias. Defined and formalized in + This layer adds the alibi bias to the attention scores. alibi bias is a + linear, non-learned bias. Defined and formalized in [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409). - Takes as input an attention score. This layer will return the attention + This layer Takes as input the attention scores. and returns the attention scores after adding the alibi bias to it. The output will have the same shape as the input. Args: alibi_bias_max: int. This value will be used to compute the slope of - each head. The heads slopes is a geometric sequence that starts at + each head. The heads' slopes is a geometric sequence that starts at `2**(-alibi_bias_max/num_heads)` and uses that same value as its ratio. Defaults to 8. Call arguments: attention_scores: The result of multipying the query and the key of the - multi head attention of the transformer. The shape must be greater - than or equal to 3 with the last 3 dimensions equal to - `(num_heads, query_length, key_length)`. + multi-head attention layer of the transformer to add alibi bias to + it. with shape `(batch_size, num_heads, query_length, key_length)`. Examples: ```python - # create a simple layer that takes token embeddings as input and generates - # the alibi tensor - seq_len = 100 - vocab_size = 1000 - embedding_dim = 32 - inputs = keras.Input((seq_len,), dtype="float32") - embedding = keras.layers.Embedding( - input_dim=vocab_size, output_dim=embedding_dim - )(inputs) - alibi_bias = keras_nlp.layers.AlibiBias(num_heads=8)(embedding) + query_length = 100 + key_length = 100 + num_heads = 8 + batch_size = 4 + hidden_dim = 12 + + # Create new alibi layer. + alibi_layer = AlibiBias() + + query = np.zeros((batch_size, num_heads, query_length, hidden_dim)) + key = np.zeros((batch_size, num_heads, hidden_dim, key_length)) + + attention_scores = ops.matmul(query, key) + + # Add alibi bias to attention scores. + attention_scores = alibi_layer(attention_scores) ``` References: @@ -68,10 +74,10 @@ def __init__( def call(self, attention_scores): shape = ops.shape(attention_scores) - if len(shape) < 3: + if len(shape) != 4: raise ValueError( "Expected `attention_scores` shape to be " - "`(..., num_heads, query_length, key_Length)`." + "`(batch_size, num_heads, query_length, key_Length)`." f" Recived shape={shape}" ) @@ -79,11 +85,6 @@ def call(self, attention_scores): num_heads = shape[-3] alibi_bias = self._get_alibi_bias(num_heads, key_length) - alibi_bias = ops.reshape( - alibi_bias, - tuple([1 for _ in range(len(shape[:-3]))]) - + (num_heads, 1, key_length), - ) return ops.add(attention_scores, alibi_bias) @@ -97,13 +98,13 @@ def _get_alibi_bias(self, num_heads, key_length): seq_range = ops.cast(seq_range, dtype=self.compute_dtype) alibi_bias = ops.multiply(slopes, seq_range) + alibi_bias = ops.expand_dims(alibi_bias, 1) - # Expand on query dimension - # return shape is `(num_heads, 1, key_length)` - return ops.expand_dims(alibi_bias, 1) + # return shape is `(1, num_heads, 1, key_length)` + return ops.expand_dims(alibi_bias, 0) def _get_slopes(self, num_heads): - # this function is adopted from Alibi original implementation + # this function is adopted from Alibi original implementation. # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 def get_slopes_power_of_2(n): start = 2 ** ( diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py index b66ce18516..5881f9a485 100644 --- a/keras_nlp/layers/modeling/alibi_bias_test.py +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -82,12 +82,12 @@ def test_dynamic_layer_output_shape(self): ) self.assertEqual(expected_output_shape, output_tensor.shape) - def test_value_error_when_inputs_shape_is_less_than_3(self): + def test_value_error_when_inputs_shape_is_not_4(self): with self.assertRaises(ValueError): AlibiBias()(random.uniform(shape=(12, 12))) def test_num_heads_is_not_power_of_two(self): - inputs_shape = (12, 12, 12) + inputs_shape = (1, 12, 12, 12) inputs = random.uniform(shape=inputs_shape) layer = AlibiBias() outputs = layer(inputs) From 9440bd646da2923b9aea926c58cfdc9796d67544 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Thu, 25 Jan 2024 14:55:52 +0200 Subject: [PATCH 20/24] Format the code --- keras_nlp/layers/modeling/alibi_bias.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index c66b9f4320..96644ecc66 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -22,7 +22,7 @@ class AlibiBias(keras.layers.Layer): """A layer that adds the alibi bias to attention scores. - This layer adds the alibi bias to the attention scores. alibi bias is a + This layer adds the alibi bias to the attention scores. alibi bias is a linear, non-learned bias. Defined and formalized in [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409). @@ -37,7 +37,7 @@ class AlibiBias(keras.layers.Layer): ratio. Defaults to 8. Call arguments: attention_scores: The result of multipying the query and the key of the - multi-head attention layer of the transformer to add alibi bias to + multi-head attention layer of the transformer to add alibi bias to it. with shape `(batch_size, num_heads, query_length, key_length)`. Examples: From 795eacdde51efbea8642bdf47287ab31f61ed1fc Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Thu, 25 Jan 2024 15:14:39 +0200 Subject: [PATCH 21/24] Fix documentation --- keras_nlp/layers/modeling/alibi_bias.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 96644ecc66..a951c0fcb7 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -22,23 +22,23 @@ class AlibiBias(keras.layers.Layer): """A layer that adds the alibi bias to attention scores. - This layer adds the alibi bias to the attention scores. alibi bias is a + This layer adds the alibi bias to the attention scores. Alibi bias is a linear, non-learned bias. Defined and formalized in [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409). - This layer Takes as input the attention scores. and returns the attention + This layer takes as input the attention scores. and returns the attention scores after adding the alibi bias to it. The output will have the same shape as the input. Args: alibi_bias_max: int. This value will be used to compute the slope of - each head. The heads' slopes is a geometric sequence that starts at + each head. The heads' slopes are a geometric sequence that starts at `2**(-alibi_bias_max/num_heads)` and uses that same value as its ratio. Defaults to 8. Call arguments: attention_scores: The result of multipying the query and the key of the multi-head attention layer of the transformer to add alibi bias to - it. with shape `(batch_size, num_heads, query_length, key_length)`. + it. With shape `(batch_size, num_heads, query_length, key_length)`. Examples: ```python From 071189f3495642b516b25a83c8c6d5d2ee287daa Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Thu, 25 Jan 2024 15:18:14 +0200 Subject: [PATCH 22/24] Fix the example --- keras_nlp/layers/modeling/alibi_bias.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index a951c0fcb7..8a66ad05af 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -42,19 +42,19 @@ class AlibiBias(keras.layers.Layer): Examples: ```python - query_length = 100 - key_length = 100 - num_heads = 8 - batch_size = 4 - hidden_dim = 12 + query_length = 10 + key_length = 10 + num_heads = 4 + batch_size = 2 + hidden_dim = 8 # Create new alibi layer. - alibi_layer = AlibiBias() + alibi_layer = keras_nlp.layers.AlibiBias() query = np.zeros((batch_size, num_heads, query_length, hidden_dim)) key = np.zeros((batch_size, num_heads, hidden_dim, key_length)) - attention_scores = ops.matmul(query, key) + attention_scores = keras.ops.matmul(query, key) # Add alibi bias to attention scores. attention_scores = alibi_layer(attention_scores) From f6bf10d2999f19738b685f6bdd8eefa6fa31e9c0 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Thu, 25 Jan 2024 15:21:26 +0200 Subject: [PATCH 23/24] Fix tensorflow2 test fail --- keras_nlp/layers/modeling/alibi_bias_test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py index 5881f9a485..8d13aac9bf 100644 --- a/keras_nlp/layers/modeling/alibi_bias_test.py +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -67,14 +67,13 @@ def test_dynamic_layer_output_shape(self): test_layer = AlibiBias() # Create a 4-dimensional input (the first dimension is implicit). input_tensor = keras.Input( - shape=(None, num_heads, query_length, key_length) + shape=(num_heads, query_length, key_length) ) output_tensor = test_layer(input_tensor) # the output is expected to be the same as the input shape in all - # dimensions - but may be None if the input shape is None there. + # dimensions. expected_output_shape = ( - None, None, num_heads, query_length, From 4f8fcd05ed207ace83473b1a9f7d2e77670f66d6 Mon Sep 17 00:00:00 2001 From: abuelnasr0 Date: Thu, 25 Jan 2024 15:21:49 +0200 Subject: [PATCH 24/24] Fix tensorflow2 test fail --- keras_nlp/layers/modeling/alibi_bias_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py index 8d13aac9bf..120c7622b1 100644 --- a/keras_nlp/layers/modeling/alibi_bias_test.py +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -66,9 +66,7 @@ def test_dynamic_layer_output_shape(self): test_layer = AlibiBias() # Create a 4-dimensional input (the first dimension is implicit). - input_tensor = keras.Input( - shape=(num_heads, query_length, key_length) - ) + input_tensor = keras.Input(shape=(num_heads, query_length, key_length)) output_tensor = test_layer(input_tensor) # the output is expected to be the same as the input shape in all