Add MistralAI's 7B Transformer as a backbone in KerasNLP Models #1314
Changes from all commits
@@ -0,0 +1,13 @@

```python
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
@@ -0,0 +1,293 @@

```python
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_nlp.utils.keras_utils import clone_initializer


# This is just a self-attention layer in Mistral. But it can be generalized
# to use the `keras_nlp.layers.CachedMultiHeadAttention` API. Since this layer
# implements grouped-query attention and sliding window attention, it might be
# useful outside of Mistral itself.
# TODO(tirthasheshpatel): Generalize the attention layer
# TODO(tirthasheshpatel): Merge `LlamaAttention` with this layer
# TODO(tirthasheshpatel): Use flash attention
class CachedMistralAttention(keras.layers.Layer):
    """A cached grouped-query attention layer with sliding window."""

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        kernel_initializer="glorot_uniform",
        sliding_window=512,
        dropout=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._num_query_heads = num_query_heads
        self._num_key_value_heads = num_key_value_heads
        self._sliding_window = sliding_window
        self._dropout = dropout

        self._num_key_value_groups = num_query_heads // num_key_value_heads
        self._rope_max_wavelength = rope_max_wavelength

        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

        self._rope_scaling_factor = rope_scaling_factor

    def build(self, inputs_shape):
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        self._hidden_dim = inputs_shape[-1]
        self._head_dim = self._hidden_dim // self._num_query_heads

        self._query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self._num_query_heads, self._head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.compute_dtype,
            name="query",
        )
        self._query_dense.build(inputs_shape)

        self._key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self._num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            dtype=self.compute_dtype,
            name="key",
        )
        self._key_dense.build(inputs_shape)

        self._value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self._num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            dtype=self.compute_dtype,
            name="value",
        )
        self._value_dense.build(inputs_shape)

        self._softmax = keras.layers.Softmax(axis=-1, name="attention_softmax")

        self._dropout_layer = keras.layers.Dropout(
            rate=self._dropout, dtype=self.compute_dtype
        )

        self._output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, self._hidden_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.compute_dtype,
            name="attention_output",
        )
        self._output_dense.build(
            (None, None, self._num_query_heads, self._head_dim)
        )

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self._rope_max_wavelength,
            scaling_factor=self._rope_scaling_factor,
            dtype=self.compute_dtype,
        )

        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
```
Comment on lines +135 to +136

Note: right now, caching doesn't work when the sequence length is greater than the sliding window. This can be addressed as a follow-up when adding the Generator model; it shouldn't be a blocker here.

Sounds good. If the upstream version doesn't solve this correctly, let's not worry too much about it here.
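For context on the rotating-buffer caching discussed above, here is a minimal NumPy sketch (not code from this PR) of how the wrap-around index arithmetic behaves when exactly one token is written per step, which is the same assumption the layer's cache update makes. The window size and token values are made up for illustration:

```python
import numpy as np

# Minimal sketch of a rotating key/value buffer. `sliding_window` and the
# token values are made up; the index math mirrors
# `(cache_update_index + seq_len - 1) % sliding_window + 1` with seq_len == 1.
sliding_window = 4
cache = np.zeros(sliding_window, dtype=int)  # stands in for cache_k / cache_v

for cache_update_index, token_id in enumerate(range(10, 17)):
    update_end_index = cache_update_index % sliding_window + 1
    update_start_index = update_end_index - 1  # write exactly one slot
    cache[update_start_index] = token_id
    print(cache_update_index, cache)

# The final state holds the four most recent tokens, rotated: [14 15 16 13]
```

Once `cache_update_index` passes the window size, new tokens overwrite the oldest slots, which is why a single cache update covering a sequence longer than the sliding window cannot be resolved unambiguously.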
```python
        training=None,
    ):
        seq_len = ops.shape(hidden_states)[1]
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )
        # If `cache_update_index` is a tensor, RotaryEmbedding expects it
        # to have dtype `self.compute_dtype`.
        start_index = ops.cast(
            start_index, self.rotary_embedding_layer.compute_dtype
        )

        query = self._query_dense(hidden_states)

        # Note that the original PyTorch implementation uses
        # view_as_complex/view_as_real while we use split/concatenate to
```
Can you explain this a bit more? Why do we need to consider complex numbers here?

Here's the Mistral source for computing the frequencies (same as Llama 2) and for computing the embeddings (also the same as Llama 2). The frequencies and inputs are treated as complex numbers, and the computation follows the "Theoretical Explanation" section of the paper. PyTorch's `view_as_complex`/`view_as_real` treat adjacent (interleaved) pairs of elements as the real and imaginary parts, while our `RotaryEmbedding` layer rotates the two halves of the tensor. This is the only fundamental difference between the two computations. We can get the same results if we shuffle the inputs such that the alternate elements get moved to the end of the tensor; hence, the transformation below moves the alternate elements to the end before applying the rotary embedding. The reverse transformation exactly mirrors/undoes what we did above.

Code demonstration of the above explanation:

```python
import torch
import numpy as np
from keras import ops


def _reshape_for_broadcast(freqs_cis, x):
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


# Llama's version of rotary embeddings
def apply_rotary_emb(
    xq,
    freqs_cis,
):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq)


# Our version of the same computation, with transformations to match the
# `apply_rotary_emb` function above.
def apply_rotary_pos_emb(tensor, cos_emb, sin_emb):
    tensor = ops.concatenate((tensor[..., ::2], tensor[..., 1::2]), axis=-1)
    x1, x2 = ops.split(tensor, 2, axis=-1)
    half_rot_tensor = ops.concatenate((-x2, x1), axis=-1)
    res = (tensor * cos_emb) + (half_rot_tensor * sin_emb)
    return res


x = np.random.standard_normal((1, 2, 1, 16))
cos_emb = np.random.standard_normal((2, 8))
sin_emb = np.random.standard_normal((2, 8))

print(x)
print(ops.concatenate((x[..., ::2], x[..., 1::2]), axis=-1))
print(np.split(np.concatenate((x[..., ::2], x[..., 1::2]), axis=-1), 2, axis=-1))
print(torch.view_as_complex(torch.tensor(x).reshape(*x.shape[:-1], -1, 2)))
print(apply_rotary_emb(torch.tensor(x), torch.tensor(cos_emb + sin_emb * 1.0j)))

y = apply_rotary_pos_emb(
    x,
    np.concatenate([cos_emb[None, :, None, :]] * 2, axis=-1),
    np.concatenate([sin_emb[None, :, None, :]] * 2, axis=-1),
)
print(
    ops.reshape(
        ops.stack(ops.split(y, 2, axis=-1), axis=-1),
        (y.shape[0], y.shape[1], y.shape[2], -1),
    )
)
```

A bit complicated, but it should be possible to achieve the same behavior by shuffling the weights using the same transformations. I believe that's what the Hugging Face folks have done, which is why this isn't required in the Llama backbone PR.

Thanks! Shuffling what weights? At the highest level, we should just consider whether we should pull this into the lower-level RotaryEmbedding layer. We want it to be useful for the most common use cases of rotary embeddings.
```python
        # convert to/from complex numbers. The transformations below make
        # the rope computation numerically equivalent to the original
        # implementation.
        def _mistral_rope(x):
            x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1)
            x = self.rotary_embedding_layer(x, start_index=start_index)
            x = ops.reshape(
                ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
            )
            return x

        # Compute RoPE for queries
        query = _mistral_rope(query)

        def _compute_key_value(x):
            key, value = self._key_dense(x), self._value_dense(x)
            key = _mistral_rope(key)
            return key, value

        if cache is not None:
            cache_k = cache[:, 0, ...]
            cache_v = cache[:, 1, ...]

            if cache_update_index is not None:
                # Compute the new keys and values
                key, value = _compute_key_value(hidden_states)

                # The cache is a rotating buffer; we want to wrap around if
                # the sequence length exceeds the sliding window.
                update_end_index = (
                    cache_update_index + seq_len - 1
                ) % self._sliding_window + 1
                update_end_index = ops.cast(update_end_index, "int32")
```
Just a general note: torch and JAX like int32 on GPU, but TensorFlow has limited op support for int32 (and does better with int64). We probably don't have coverage for this code path on the TensorFlow backend with a GPU accelerator yet, but it might come up down the line.
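As a purely illustrative aside (not something this PR does), a backend-dependent dtype choice could look roughly like the sketch below; `keras.backend.backend()` is an existing Keras call, but the guard itself is a hypothetical example:

```python
from keras import backend

# Hypothetical guard, not from this PR: JAX and torch favor int32 indices on
# GPU, while TensorFlow kernels often have wider int64 coverage there.
index_dtype = "int64" if backend.backend() == "tensorflow" else "int32"
print(index_dtype)
```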
```python
                cache_update_index = cache_update_index % self._sliding_window
                update_start_index = ops.cond(
                    update_end_index > cache_update_index,
                    lambda: ops.cast(cache_update_index, "int32"),
                    lambda: ops.cast(0, "int32"),
                )
                # Also note that the update step below assumes that the
                # sequence length is always one when `cache_update_index != 0`.
                # This is necessary to support XLA compilation. Ideally, we
                # would want to use
                # `key[:, -(update_end_index - update_start_index):, ...]`
                # as the update, but updating with a dynamic slice gives an
                # XLA compilation error in TensorFlow.
                # Passing a sequence of length > 1 with a cache update might
                # give incorrect results (since there is no way to determine
                # how many of the most recent tokens should be saved if the
                # tokens exceed the sliding window length).
                cache_k = ops.slice_update(
                    cache_k,
                    [0, update_start_index, 0, 0],
                    # We slice the keys and values in case the user has passed
                    # a sequence of length > `self._sliding_window`; we want to
                    # prefill the cache using just the most recent values in
                    # the sliding window.
                    ops.cast(
                        key[:, -self._sliding_window :, ...], cache_k.dtype
                    ),
                )
                cache_v = ops.slice_update(
                    cache_v,
                    [0, update_start_index, 0, 0],
                    ops.cast(
                        value[:, -self._sliding_window :, ...], cache_v.dtype
                    ),
                )
                cache = ops.stack([cache_k, cache_v], axis=1)

            # Get the required keys and values from the cache.
            # Since we expect the user to pass a fixed-size cache, we just
            # pick the first few slices up to and including the newly computed
            # keys and values.
            cache_k = cache_k[:, :update_end_index, ...]
            cache_v = cache_v[:, :update_end_index, ...]
```
Comment on lines +227 to +228

JAX fails here if `update_end_index` is a traced value, since the slice bound would make the result's shape dynamic.

Yeah, this looks unsupported by XLA today, as this would involve dynamic shapes in a compiled function.
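To illustrate the constraint being discussed, here is a standalone JAX sketch (not code from this PR; the function names and values are made up). Slicing with a data-dependent bound fails under `jit` because the output shape would be dynamic, while `lax.dynamic_slice` with a static slice size works:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16)


@jax.jit
def dynamic_bound_slice(x, end):
    # Fails when traced: the output shape would depend on the runtime
    # value of `end`, which XLA does not support.
    return x[:end]


@jax.jit
def static_size_slice(x, start):
    # Works: the slice size (4) is static; only the start offset is dynamic.
    return lax.dynamic_slice(x, (start,), (4,))


print(static_size_slice(x, 3))  # [3 4 5 6]

try:
    dynamic_bound_slice(x, 8)
except Exception as err:
    print(type(err).__name__)  # a tracer/concretization error
```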
```python
            key = ops.cast(cache_k, dtype=self.compute_dtype)
            value = ops.cast(cache_v, dtype=self.compute_dtype)
        else:
            # Compute keys and values
            key, value = _compute_key_value(hidden_states)

        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query, key, value, attention_mask
        )

        attention_output = self._dropout_layer(
            attention_output, training=training
        )

        attention_output = self._output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        if attention_mask is not None:
            return self._softmax(
                attention_scores, attention_mask[:, None, :, :]
            )
        return self._softmax(attention_scores)

    def _compute_attention(self, query, key, value, attention_mask=None):
        attention_scores = ops.einsum(self._dot_product_equation, key, query)

        norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))

        attention_scores = attention_scores / norm_factor

        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        attention_output = ops.einsum(
            self._combine_equation, attention_scores, value
        )

        return attention_output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self._num_query_heads,
                "num_key_value_heads": self._num_key_value_heads,
                "rope_max_wavelength": self._rope_max_wavelength,
                "rope_scaling_factor": self._rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self._kernel_initializer
                ),
                "sliding_window": self._sliding_window,
                "dropout": self._dropout,
            }
        )
        return config
```