Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a76a91b
Add AlibiBias layer
abuelnasr0 Jan 18, 2024
3f3d968
Add example
abuelnasr0 Jan 18, 2024
f1536df
Convert layer to recieve attn_scores and add the alibi bias to it.
abuelnasr0 Jan 22, 2024
9c891ae
Change layer logic
abuelnasr0 Jan 22, 2024
3676476
Add layer test
abuelnasr0 Jan 22, 2024
358e6f0
Format the code
abuelnasr0 Jan 22, 2024
0c949c4
Fix seq_range creation to be int range
abuelnasr0 Jan 22, 2024
5f1ebc1
Change bloom model to use alibi bias
abuelnasr0 Jan 22, 2024
aa3b15f
Format the code
abuelnasr0 Jan 22, 2024
f462ee2
Remove print function
abuelnasr0 Jan 23, 2024
726946d
Change logic to only compute alibi bias once
abuelnasr0 Jan 23, 2024
56a8c59
Change bloom model API calls to much new alibi layer API
abuelnasr0 Jan 23, 2024
7408fd8
Format the code
abuelnasr0 Jan 23, 2024
91c0e04
Add dtype kwarg for the layer weight
abuelnasr0 Jan 23, 2024
410385f
Revert "Add dtype kwarg for the layer weight"
abuelnasr0 Jan 23, 2024
8fc2616
Cast after adding alibi bias
abuelnasr0 Jan 23, 2024
8e750b7
Return to compute ALibi bias at each call
abuelnasr0 Jan 25, 2024
a44c6e8
Add compute output shape method for bloom decoder
abuelnasr0 Jan 25, 2024
7f8cc0f
Force shape to be (batch_size, num_heads, query_length, key_length)
abuelnasr0 Jan 25, 2024
9440bd6
Format the code
abuelnasr0 Jan 25, 2024
795eacd
Fix documentation
abuelnasr0 Jan 25, 2024
071189f
Fix the example
abuelnasr0 Jan 25, 2024
f6bf10d
Fix tensorflow2 test fail
abuelnasr0 Jan 25, 2024
4f8fcd0
Fix tensorflow2 test fail
abuelnasr0 Jan 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions keras_nlp/layers/modeling/alibi_bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops


@keras_nlp_export("keras_nlp.layers.AlibiBias")
class AlibiBias(keras.layers.Layer):
"""A layer that adds the alibi bias to attention scores.

This layer adds the alibi bias to the attention scores. Alibi bias is a
linear, non-learned bias. Defined and formalized in
[Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409).

This layer takes as input the attention scores. and returns the attention
scores after adding the alibi bias to it. The output will have the same
shape as the input.

Args:
alibi_bias_max: int. This value will be used to compute the slope of
each head. The heads' slopes are a geometric sequence that starts at
`2**(-alibi_bias_max/num_heads)` and uses that same value as its
ratio. Defaults to 8.
Call arguments:
attention_scores: The result of multipying the query and the key of the
multi-head attention layer of the transformer to add alibi bias to
it. With shape `(batch_size, num_heads, query_length, key_length)`.

Examples:
```python
query_length = 10
key_length = 10
num_heads = 4
batch_size = 2
hidden_dim = 8

# Create new alibi layer.
alibi_layer = keras_nlp.layers.AlibiBias()

query = np.zeros((batch_size, num_heads, query_length, hidden_dim))
key = np.zeros((batch_size, num_heads, hidden_dim, key_length))

attention_scores = keras.ops.matmul(query, key)

# Add alibi bias to attention scores.
attention_scores = alibi_layer(attention_scores)
```

References:
- [Press et al., 2021](https://arxiv.org/abs/2108.12409)
"""

def __init__(
self,
alibi_bias_max=8,
**kwargs,
):
super().__init__(**kwargs)
self.alibi_bias_max = alibi_bias_max

def call(self, attention_scores):
shape = ops.shape(attention_scores)
if len(shape) != 4:
raise ValueError(
"Expected `attention_scores` shape to be "
"`(batch_size, num_heads, query_length, key_Length)`."
f" Recived shape={shape}"
)

key_length = shape[-1]
num_heads = shape[-3]

alibi_bias = self._get_alibi_bias(num_heads, key_length)

return ops.add(attention_scores, alibi_bias)

def _get_alibi_bias(self, num_heads, key_length):
slopes = ops.convert_to_tensor(
self._get_slopes(num_heads), dtype=self.compute_dtype
)
slopes = ops.expand_dims(slopes, 1)

seq_range = ops.expand_dims(ops.arange(1 - key_length, 1), 0)
seq_range = ops.cast(seq_range, dtype=self.compute_dtype)

alibi_bias = ops.multiply(slopes, seq_range)
alibi_bias = ops.expand_dims(alibi_bias, 1)

# return shape is `(1, num_heads, 1, key_length)`
return ops.expand_dims(alibi_bias, 0)

def _get_slopes(self, num_heads):
# this function is adopted from Alibi original implementation.
# https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
def get_slopes_power_of_2(n):
start = 2 ** (
-(2 ** -(math.log2(n) - math.log2(self.alibi_bias_max)))
)
ratio = start
return [start * ratio**i for i in range(n)]

if math.log2(num_heads).is_integer():
return get_slopes_power_of_2(num_heads)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
return (
get_slopes_power_of_2(closest_power_of_2)
+ self._get_slopes(2 * closest_power_of_2)[0::2][
: num_heads - closest_power_of_2
]
)

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self):
config = super().get_config()
config.update(
{
"alibi_bias_max": self.alibi_bias_max,
}
)
return config
202 changes: 202 additions & 0 deletions keras_nlp/layers/modeling/alibi_bias_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.backend import random
from keras_nlp.layers.modeling.alibi_bias import AlibiBias
from keras_nlp.tests.test_case import TestCase


class AlibiBiasTest(TestCase):
def test_layer_behaviors(self):
alibi_bias_max = 8
batch_size = 4
num_heads = 8
query_length = 10
key_length = 10
self.run_layer_test(
cls=AlibiBias,
init_kwargs={
"alibi_bias_max": alibi_bias_max,
},
input_data=random.uniform(
shape=(batch_size, num_heads, query_length, key_length)
),
expected_output_shape=(
batch_size,
num_heads,
query_length,
key_length,
),
)

def test_float16_dtype(self):
# Create a 4-dimensional input (the first dimension is implicit).
alibi_bias_max = 8
num_heads = 8
query_length = 5
key_length = 10
test_layer = AlibiBias(alibi_bias_max=alibi_bias_max, dtype="float16")
input_tensor = keras.Input(shape=(num_heads, query_length, key_length))
output_tensor = test_layer(input_tensor)

# the output is expected to be the same as the input shape in all
# dimensions. here, the first dimension is implicit and is for batch
expected_output_shape = (None, num_heads, query_length, key_length)
self.assertEqual(expected_output_shape, output_tensor.shape)
# The default output dtype for this layer should be "float32".
self.assertEqual("float16", output_tensor.dtype)

def test_dynamic_layer_output_shape(self):
query_length = 10
key_length = 10
num_heads = 4

test_layer = AlibiBias()
# Create a 4-dimensional input (the first dimension is implicit).
input_tensor = keras.Input(shape=(num_heads, query_length, key_length))
output_tensor = test_layer(input_tensor)

# the output is expected to be the same as the input shape in all
# dimensions.
expected_output_shape = (
None,
num_heads,
query_length,
key_length,
)
self.assertEqual(expected_output_shape, output_tensor.shape)

def test_value_error_when_inputs_shape_is_not_4(self):
with self.assertRaises(ValueError):
AlibiBias()(random.uniform(shape=(12, 12)))

def test_num_heads_is_not_power_of_two(self):
inputs_shape = (1, 12, 12, 12)
inputs = random.uniform(shape=inputs_shape)
layer = AlibiBias()
outputs = layer(inputs)
self.assertEqual(inputs_shape, outputs.shape)

def test_correct_output(self):
batch_size = 1
num_heads = 8
query_length = 1
key_length = 3
input_shape = (batch_size, num_heads, query_length, key_length)
input_tensor = ops.zeros(input_shape)
layer = AlibiBias()
output_tensor = layer(input_tensor)
print(output_tensor)
self.assertAllClose(
output_tensor,
ops.convert_to_tensor(
[
[
[[-1.0, -0.5, 0.0]],
[[-0.5, -0.25, 0.0]],
[[-0.25, -0.125, 0.0]],
[[-0.125, -0.0625, 0.0]],
[[-0.0625, -0.03125, 0.0]],
[[-0.03125, -0.015625, 0.0]],
[[-0.015625, -0.0078125, 0.0]],
[[-0.0078125, -0.00390625, 0.0]],
]
]
),
)

def test_correct_output_num_heads_not_power_of_two(self):
batch_size = 1
num_heads = 14
query_length = 1
key_length = 3
input_shape = (batch_size, num_heads, query_length, key_length)
input_tensor = ops.zeros(input_shape)
layer = AlibiBias()
output_tensor = layer(input_tensor)
print(output_tensor)
self.assertAllClose(
output_tensor,
ops.convert_to_tensor(
[
[
[[-1.0, -0.5, 0.0]],
[[-0.5, -0.25, 0.0]],
[[-0.25, -0.125, 0.0]],
[[-0.125, -0.0625, 0.0]],
[[-0.0625, -0.03125, 0.0]],
[[-0.03125, -0.015625, 0.0]],
[[-0.015625, -0.0078125, 0.0]],
[[-0.0078125, -0.00390625, 0.0]],
[[-1.4142135, -0.70710677, 0.0]],
[[-0.70710677, -0.35355338, 0.0]],
[[-0.35355338, -0.17677669, 0.0]],
[[-0.17677669, -0.08838835, 0.0]],
[[-0.08838835, -0.04419417, 0.0]],
[[-0.04419417, -0.02209709, 0.0]],
]
]
),
)

def test_correct_output_alibi_bias_max(self):
alibi_bias_max = 12
batch_size = 1
num_heads = 2
query_length = 1
key_length = 3
input_shape = (batch_size, num_heads, query_length, key_length)
input_tensor = ops.zeros(input_shape)
layer = AlibiBias(alibi_bias_max=alibi_bias_max)
output_tensor = layer(input_tensor)
print(output_tensor)
self.assertAllClose(
output_tensor,
ops.convert_to_tensor(
[
[
[[-0.03125, -0.015625, 0.0]],
[[-0.00048828, -0.00024414, 0.0]],
]
]
),
)

def test_correct_output_alibi_bias_max_num_heads_not_power_of_two(
self,
):
alibi_bias_max = 6
batch_size = 1
num_heads = 3
query_length = 1
key_length = 3
input_shape = (batch_size, num_heads, query_length, key_length)
input_tensor = ops.zeros(input_shape)
layer = AlibiBias(alibi_bias_max=alibi_bias_max)
output_tensor = layer(input_tensor)
print(output_tensor)
self.assertAllClose(
output_tensor,
ops.convert_to_tensor(
[
[
[[-0.25, -0.125, 0.0]],
[[-0.03125, -0.015625, 0.0]],
[[-0.70710677, -0.35355338, 0.0]],
]
]
),
)
Loading