Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions keras_nlp/layers/modeling/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,45 +85,52 @@ def __init__(
self.built = True

def call(self, inputs, start_index=0):
rotary_dim = ops.shape(inputs)[-1]
cos_emb, sin_emb = self._compute_cos_sin_embedding(
inputs, rotary_dim, start_index
)
cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, start_index)
return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)

def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
x1, x2 = ops.split(tensor, 2, axis=self.feature_axis)
half_rot_tensor = ops.concatenate((-x2, x1), axis=self.feature_axis)
return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

def _compute_cos_sin_embedding(self, x, rotary_dim, start_index):
freq_range = ops.arange(0, rotary_dim, 2)
freq_range = ops.cast(freq_range, self.compute_dtype)
freq_range = freq_range / ops.cast(
self.scaling_factor, self.compute_dtype
)
inverse_freq = 1.0 / (
self.max_wavelength
** (freq_range / ops.cast(rotary_dim, self.compute_dtype))
)
seq_len = ops.shape(x)[self.sequence_axis]
tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index
tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
freq = ops.einsum("i, j -> ij", tensor, inverse_freq)
embedding = ops.concatenate((freq, freq), axis=self.feature_axis)

def _compute_cos_sin_embedding(self, inputs, start_index=0):
def get_axis(axis):
return axis if axis > 0 else len(x.shape) + axis
return axis if axis > 0 else len(inputs.shape) + axis

feature_axis = get_axis(self.feature_axis)
sequence_axis = get_axis(self.sequence_axis)

for axis in range(len(x.shape)):
rotary_dim = ops.shape(inputs)[feature_axis]
inverse_freq = self._get_inverse_freq(rotary_dim)

seq_len = ops.shape(inputs)[self.sequence_axis]
tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index

tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
freq = ops.einsum("i,j->ij", tensor, inverse_freq)
embedding = ops.concatenate((freq, freq), axis=-1)

# Reshape the embedding to be broadcastable with input shape.
if feature_axis < sequence_axis:
embedding = ops.transpose(embedding)
for axis in range(len(inputs.shape)):
if axis != sequence_axis and axis != feature_axis:
embedding = ops.expand_dims(embedding, axis)

return ops.cos(embedding), ops.sin(embedding)

def _get_inverse_freq(self, rotary_dim):
freq_range = ops.arange(0, rotary_dim, 2)
freq_range = ops.cast(freq_range, self.compute_dtype)
freq_range = freq_range / ops.cast(
self.scaling_factor, self.compute_dtype
)
inverse_freq = 1.0 / (
self.max_wavelength
** (freq_range / ops.cast(rotary_dim, self.compute_dtype))
)
return inverse_freq

def get_config(self):
config = super().get_config()
config.update(
Expand Down
12 changes: 12 additions & 0 deletions keras_nlp/layers/modeling/rotary_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ def test_start_index(self):
)
self.assertAllClose(full_output, sequential_output)

def test_permuted_axes(self):
batch_size, seq_length, feature_size = 2, 3, 4
data = random.uniform(shape=(batch_size, seq_length, feature_size))
layer = RotaryEmbedding(seq_length)
outputs = layer(data)
permuted_data = ops.transpose(data, (0, 2, 1))
permuted_layer = RotaryEmbedding(
seq_length, sequence_axis=-1, feature_axis=-2
)
permuted_outputs = permuted_layer(permuted_data)
self.assertAllClose(outputs, ops.transpose(permuted_outputs, (0, 2, 1)))

def test_float16_dtype(self):
embedding_layer = RotaryEmbedding(dtype="float16")
seq_length = 100
Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
GPTNeoXPreprocessor,
)
from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.models.llama.llama_backbone import LlamaBackbone
from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
Expand Down
13 changes: 13 additions & 0 deletions keras_nlp/models/llama/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
201 changes: 201 additions & 0 deletions keras_nlp/models/llama/llama_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_nlp.utils.keras_utils import clone_initializer


class LlamaAttention(keras.layers.Layer):
"""Grouped query attention for Llama models"""

def __init__(
self,
num_query_heads,
num_key_value_heads,
rope_scaling_factor=1.0,
kernel_initializer="glorot_uniform",
rope_max_wavelength=10000,
max_sequence_length=512,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads

self.num_key_value_groups = num_query_heads // num_key_value_heads

self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.max_sequence_length = max_sequence_length

self.rope_scaling_factor = rope_scaling_factor
self.rope_max_wavelength = rope_max_wavelength

def build(self, inputs_shape):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as mistral... Consider something like this, where we collocate all einsum equations in build, and we add a nice key at the top. Helps readability.

https://github.com/keras-team/keras/blob/master/keras/layers/attention/grouped_query_attention.py#L124-L167

(ok if we want to punt on this for this pr)

Copy link
Collaborator Author

@kanpuriyanawab kanpuriyanawab Nov 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this looks good! Added.

self.hidden_dim = inputs_shape[-1]
self.attn_head_size = self.hidden_dim // self.num_query_heads

# Einsum variables:
# b = batch size
# q = query length
# k = key/value length
# m = model dim
# u = num query heads
# v = num key/value heads
# h = head dim
self._query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self.num_query_heads, self.attn_head_size),
kernel_initializer=clone_initializer(self.kernel_initializer),
name="query",
)
self._query_dense.build(inputs_shape)
self._key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(None, self.num_key_value_heads, self.attn_head_size),
kernel_initializer=clone_initializer(self.kernel_initializer),
name="key",
)
self._key_dense.build(inputs_shape)

self._value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(None, self.num_key_value_heads, self.attn_head_size),
kernel_initializer=clone_initializer(self.kernel_initializer),
name="value",
)
self._value_dense.build(inputs_shape)

self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")

self._output_dense = keras.layers.EinsumDense(
equation="bqm,mh->bqh",
output_shape=(None, self.hidden_dim),
kernel_initializer=clone_initializer(self.kernel_initializer),
name="attention_output",
)
self._output_dense.build(inputs_shape)

self._rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self.rope_max_wavelength,
scaling_factor=self.rope_scaling_factor,
)
self._rotary_embedding_layer.build(inputs_shape)

self.built = True

def call(
self,
hidden_states,
attention_mask=None,
cache=None,
cache_update_index=None,
):
query = self._query_dense(hidden_states)

if cache is not None:
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
key_update = self._key_dense(hidden_states)
value_update = self._value_dense(hidden_states)
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)
key = self._key_dense(hidden_states)
value = self._value_dense(hidden_states)

query = self._rotary_embedding_layer(query)
key = self._rotary_embedding_layer(key)

key = ops.tile(key, [1, 1, self.num_key_value_groups, 1])
value = ops.tile(value, [1, 1, self.num_key_value_groups, 1])

attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask
)

attention_output_shape = ops.shape(attention_output)

attention_output = ops.reshape(
attention_output,
[
attention_output_shape[0],
attention_output_shape[1],
self.hidden_dim,
],
)

attention_output = self._output_dense(attention_output)

if cache is not None:
return (attention_output, cache)
return attention_output

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
mask_expansion_axis = -3
for _ in range(
len(attention_scores.shape) - len(attention_mask.shape)
):
attention_mask = ops.expand_dims(
attention_mask, axis=mask_expansion_axis
)
Comment on lines +159 to +165
Copy link
Contributor

@tirthasheshpatel tirthasheshpatel Nov 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the inputs are constrained to be 3 dimensional, we can simplify this as:

Suggested change
mask_expansion_axis = -3
for _ in range(
len(attention_scores.shape) - len(attention_mask.shape)
):
attention_mask = ops.expand_dims(
attention_mask, axis=mask_expansion_axis
)
attention_mask = attention_mask[:, None, :, :]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC @mattdangerw and I had a conversation about it. Let's keep this as is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No strong feeling. The thing to keep in mind here is what is public API and what's internal to the model.

RotaryEmbedding is public, that's the one we want to support with multiple different call ranks/configurations.

Llama attention is unexposed, so it's ok to make assumptions about the input shape as long as it's valid for llama models.

return self._softmax(attention_scores, attention_mask)

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum("aecd,abcd->acbe", key, query)

norm_factor = ops.sqrt(
ops.convert_to_tensor(self.attn_head_size, self.compute_dtype)
)

attention_scores /= norm_factor

attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_output = ops.einsum(
"acbe,aecd->abcd", attention_scores, value
)

return attention_output, attention_scores

def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self.num_query_heads,
"hidden_dim": self.hidden_dim,
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"num_key_value_heads": self.num_key_value_heads,
"max_sequence_length": self.max_sequence_length,
}
)
return config
Loading