diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index 7b898138fd..8fd6a70ac0 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -93,6 +93,12 @@ from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
 from keras_nlp.models.llama.llama_backbone import LlamaBackbone
 from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
+from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM
+from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
+    MistralCausalLMPreprocessor,
+)
+from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor
+from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
 from keras_nlp.models.opt.opt_backbone import OPTBackbone
 from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
 from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
diff --git a/keras_nlp/models/mistral/mistral_attention.py b/keras_nlp/models/mistral/mistral_attention.py
index 680f1f6d1b..1c58c59150 100644
--- a/keras_nlp/models/mistral/mistral_attention.py
+++ b/keras_nlp/models/mistral/mistral_attention.py
@@ -136,7 +136,6 @@ def call(
         cache_update_index=None,
         training=None,
     ):
-        seq_len = ops.shape(hidden_states)[1]
         start_index = (
             cache_update_index if cache_update_index is not None else 0
         )
@@ -148,89 +147,34 @@ def call(
 
         query = self._query_dense(hidden_states)
 
-        # Note that the original PyTorch implementation uses
-        # view_as_complex/view_as_real while we use split/concatenate to
-        # convert to/from complex numbers. The transformations below make
-        # the rope computation numerically equivalent to the original
-        # implementation.
-        def _mistral_rope(x):
-            x = ops.concatenate([x[..., ::2], x[..., 1::2]], axis=-1)
-            x = self.rotary_embedding_layer(x, start_index=start_index)
-            x = ops.reshape(
-                ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
-            )
-            return x
-
         # Compute RoPE for queries
-        query = _mistral_rope(query)
+        query = self.rotary_embedding_layer(query, start_index=start_index)
 
         def _compute_key_value(x):
             key, value = self._key_dense(x), self._value_dense(x)
-            key = _mistral_rope(key)
+            # Compute RoPE for keys
+            key = self.rotary_embedding_layer(key, start_index=start_index)
             return key, value
 
         if cache is not None:
-            cache_k = cache[:, 0, ...]
-            cache_v = cache[:, 1, ...]
-
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
             if cache_update_index is not None:
-                # Compute the new keys and values
-                key, value = _compute_key_value(hidden_states)
-
-                # Cache is a rotating buffer, we want to warp around if
-                # the sequence length exceeds the sliding window.
-                update_end_index = (
-                    cache_update_index + seq_len - 1
-                ) % self._sliding_window + 1
-                update_end_index = ops.cast(update_end_index, "int32")
-                cache_update_index = cache_update_index % self._sliding_window
-                update_start_index = ops.cond(
-                    update_end_index > cache_update_index,
-                    lambda: ops.cast(cache_update_index, "int32"),
-                    lambda: ops.cast(0, "int32"),
-                )
-                # Also note that the update step below assumes that the
-                # sequence length is always one when `cache_update_index != 0`.
-                # This is necessary to support XLA compilation. Ideally, we
-                # would want to use
-                # `key[:, -(update_end_index - update_start_index):, ...]`
-                # as the update but updating using a dynamic slice gives an
-                # XLA compilation error in TensorFlow.
-                # Passing a sequence of length > 1 with cache update might give
-                # incorrect results (since there is no way to determine how
-                # many most recent tokens are to be saved if the tokens exceed
-                # the sliding window length).
-                cache_k = ops.slice_update(
-                    cache_k,
-                    [0, update_start_index, 0, 0],
-                    # We slice the keys and values since if the user has passed
-                    # a sequence of length > `self._sliding_window`. We want to
-                    # prefill the cache using just the most recent values in the
-                    # sliding window.
-                    ops.cast(
-                        key[:, -self._sliding_window :, ...], cache_k.dtype
-                    ),
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
                 )
-                cache_v = ops.slice_update(
-                    cache_v,
-                    [0, update_start_index, 0, 0],
-                    ops.cast(
-                        value[:, -self._sliding_window :, ...], cache_v.dtype
-                    ),
-                )
-                cache = ops.stack([cache_k, cache_v], axis=1)
-
-            # Get the required keys and values from the cache.
-            # Since we expect the user to pass a fixed-size cache, we just
-            # pick the first few slices up-to and including the newly computed
-            # keys and values.
-            cache_k = cache_k[:, :update_end_index, ...]
-            cache_v = cache_v[:, :update_end_index, ...]
-
-            key = ops.cast(cache_k, dtype=self.compute_dtype)
-            value = ops.cast(cache_v, dtype=self.compute_dtype)
-        else:
-            # Compute keys and values
             key, value = _compute_key_value(hidden_states)
 
         # [batch_shape, seq_len, num_key_value_heads, head_dim]
@@ -260,7 +204,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
 
         norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))
 
diff --git a/keras_nlp/models/mistral/mistral_causal_lm.py b/keras_nlp/models/mistral/mistral_causal_lm.py
new file mode 100644
index 0000000000..22defbc456
--- /dev/null
+++ b/keras_nlp/models/mistral/mistral_causal_lm.py
@@ -0,0 +1,213 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import keras
+from keras_nlp.backend import ops
+from keras_nlp.models.generative_task import GenerativeTask
+from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
+from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
+    MistralCausalLMPreprocessor,
+)
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.MistralCausalLM")
+class MistralCausalLM(GenerativeTask):
+    """An end-to-end Mistral model for causal language modeling.
+
+    A causal language model (LM) predicts the next token based on previous
+    tokens. This task setup can be used to train the model unsupervised on
+    plain text input, or to autoregressively generate plain text similar to
+    the data used for training. This task can be used for pre-training or
+    fine-tuning a Mistral model, simply by calling `fit()`.
+
+    This model has a `generate()` method, which generates text based on a
+    prompt. The generation strategy used is controlled by an additional
+    `sampler` argument on `compile()`. You can recompile the model with
+    different `keras_nlp.samplers` objects to control the generation. By
+    default, `"top_k"` sampling will be used.
+
+    Args:
+        backbone: A `keras_nlp.models.MistralBackbone` instance.
+        preprocessor: A `keras_nlp.models.MistralCausalLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
+    """
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.inputs
+        hidden_states = backbone(inputs)
+        outputs = backbone.token_embedding(hidden_states, reverse=True)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Default compilation ===
+        self.compile(
+            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+            optimizer=keras.optimizers.Adam(2e-5),
+            metrics=[keras.metrics.SparseCategoricalAccuracy()],
+            jit_compile=True,
+        )
+
+    @classproperty
+    def backbone_cls(cls):
+        return MistralBackbone
+
+    @classproperty
+    def preprocessor_cls(cls):
+        return MistralCausalLMPreprocessor
+
+    def call_with_cache(
+        self,
+        token_ids,
+        cache,
+        cache_update_index,
+    ):
+        """Forward pass of `MistralCausalLM` with cache.
+
+        `call_with_cache` adds an additional forward pass for the model for
+        autoregressive inference. Unlike calling the model directly, this
+        method allows caching previous key/value Tensors in the multi-head
+        attention layers, and avoids recomputing the outputs of seen tokens.
+
+        Args:
+            token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+            cache: a dense float Tensor, the cache of key and value.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
+
+        Returns:
+            A (logits, hidden_states, cache) tuple. Where `logits` is the
+            language model logits for the input token_ids, `hidden_states` is
+            the final hidden representation of the input tokens, and `cache` is
+            the decoding cache.
+        """
+        x = self.backbone.token_embedding(token_ids)
+        # Each decoder layer has a cache; we update them separately.
+        updated_cache = []
+        for i in range(self.backbone.num_layers):
+            current_cache = cache[:, i, ...]
+            x, next_cache = self.backbone.transformer_layers[i](
+                x,
+                self_attention_cache=current_cache,
+                self_attention_cache_update_index=cache_update_index,
+            )
+            updated_cache.append(next_cache)
+        cache = ops.stack(updated_cache, axis=1)
+        hidden_states = x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+        return logits, hidden_states, cache
+
+    def _build_cache(self, token_ids):
+        """Build an empty cache for use with `call_with_cache()`."""
+        batch_size = ops.shape(token_ids)[0]
+        max_length = ops.shape(token_ids)[1]
+        num_layers = self.backbone.num_layers
+        num_key_value_heads = self.backbone.num_key_value_heads
+        head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads
+        shape = [
+            batch_size,
+            num_layers,
+            2,
+            max_length,
+            num_key_value_heads,
+            head_dim,
+        ]
+        cache = ops.zeros(shape, dtype=self.compute_dtype)
+        # Seed the cache.
+        _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
+        return hidden_states, cache
+
+    def generate_step(
+        self,
+        inputs,
+        end_token_id=None,
+    ):
+        """A compilable generation function for a single batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
+
+        Args:
+            inputs: A dictionary with two keys `"token_ids"` and
+                `"padding_mask"` and batched tensor values.
+            end_token_id: The id of the end token to stop on. If all
+                sequences have produced a new `end_token_id`, generation
+                will stop.
+        """
+        token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
+        # Create and seed cache with a single forward pass.
+        hidden_states, cache = self._build_cache(token_ids)
+        # Compute the lengths of all user-provided token ids.
+        row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+        # Start at the first index that has no user-provided id.
+        index = ops.min(row_lengths)
+
+        def next(prompt, cache, index):
+            # The cache index is the index of our previous token.
+            cache_update_index = index - 1
+            batch_size = ops.shape(prompt)[0]
+            prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+            logits, hidden_states, cache = self.call_with_cache(
+                prompt,
+                cache,
+                cache_update_index,
+            )
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(hidden_states, axis=1),
+                cache,
+            )
+
+        token_ids = self._sampler(
+            next=next,
+            prompt=token_ids,
+            cache=cache,
+            index=index,
+            mask=padding_mask,
+            end_token_id=end_token_id,
+            hidden_states=hidden_states,
+        )
+
+        # Compute an output padding mask with the token ids we updated.
+        if end_token_id is not None:
+            # Build a mask of `end_token_id` locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = ops.logical_and(
+                ops.equal(token_ids, end_token_id),
+                ops.logical_not(padding_mask),
+            )
+            end_locations = ops.cast(end_locations, "int32")
+            # Use cumsum to get ones in all locations after end_locations.
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = ops.ones_like(token_ids, dtype="bool")
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
diff --git a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..c8a0821733
--- /dev/null
+++ b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py
@@ -0,0 +1,171 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from absl import logging
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import ops
+from keras_nlp.models.mistral.mistral_preprocessor import MistralPreprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.MistralCausalLMPreprocessor")
+class MistralCausalLMPreprocessor(MistralPreprocessor):
+    """Mistral Causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_nlp.models.MistralCausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this
+    preprocessor is attached to a `keras_nlp.models.MistralCausalLM` instance,
+    these methods will be called implicitly in `generate()`. They can also be
+    called standalone (e.g. to precompute preprocessing inputs for generation
+    in a separate process).
+
+    Args:
+        tokenizer: A `keras_nlp.models.MistralTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Default is `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Default is `False`.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.MistralCausalLMPreprocessor.from_preset(
+        "mistral_base_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    sentence = tf.constant("League of legends")
+    preprocessor(sentence)
+    # Same output.
+    preprocessor("League of legends")
+
+    # Tokenize a batch of sentences.
+    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+    preprocessor(sentences)
+    # Same output.
+    preprocessor(["Taco tuesday", "Fish taco please!"])
+
+    # Map a dataset to preprocess a single sentence.
+    features = tf.constant(
+        [
+            "Avatar 2 is amazing!",
+            "Well, I am not sure.",
+        ]
+    )
+    labels = tf.constant([1, 0])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ```
+    """
+
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        if y is not None or sample_weight is not None:
+            logging.warning(
+                "`MistralCausalLMPreprocessor` generates `y` and "
+                "`sample_weight` based on your input data, but your data "
+                "already contains `y` or `sample_weight`. Your `y` and "
+                "`sample_weight` will be ignored."
+            )
+        sequence_length = sequence_length or self.sequence_length
+
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        # Pad with one extra token to account for the truncation below.
+        token_ids, padding_mask = self.packer(
+            x,
+            sequence_length=sequence_length + 1,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        # The last token does not have a next token, so we truncate it out.
+        x = {
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        # Target `y` will be the next token.
+        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+        return pack_x_y_sample_weight(x, y, sample_weight)
+
+    def generate_preprocess(
+        self,
+        x,
+        sequence_length=None,
+    ):
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        token_ids, padding_mask = self.packer(
+            x, sequence_length=sequence_length, add_end_value=False
+        )
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def generate_postprocess(
+        self,
+        x,
+    ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()`, by first removing all
+        padding and start/end tokens, and then converting the integer sequence
+        back to a string.
+        """
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        # Convert the inputs to numpy arrays if they aren't a tensor already.
+        if not isinstance(token_ids, tf.Tensor):
+            token_ids = ops.convert_to_numpy(token_ids)
+            # Make sure the numpy array has type `int32` since
+            # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
+            token_ids = token_ids.astype("int32")
+        if not isinstance(padding_mask, tf.Tensor):
+            padding_mask = ops.convert_to_numpy(padding_mask)
+            padding_mask = padding_mask.astype("bool")
+        # Strip any special tokens during detokenization (e.g. the start and
+        # end markers). In the future we could make this configurable.
+        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
+        padding_mask = padding_mask & (
+            token_ids != self.tokenizer.start_token_id
+        )
+        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
+        return self.tokenizer.detokenize(token_ids)
diff --git a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py
new file mode 100644
index 0000000000..420995016b
--- /dev/null
+++ b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor_test.py
@@ -0,0 +1,81 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
+    MistralCausalLMPreprocessor,
+)
+from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+class MistralCausalLMPreprocessorTest(TestCase):
+    def setUp(self):
+        self.tokenizer = MistralTokenizer(
+            # Generated using create_mistral_test_proto.py
+            proto=os.path.join(
+                self.get_test_data_dir(), "mistral_test_vocab.spm"
+            )
+        )
+        self.init_kwargs = {
+            "tokenizer": self.tokenizer,
+            "sequence_length": 8,
+        }
+        self.input_data = (["the quick brown fox"],)
+
+    def test_preprocessor_basics(self):
+        self.run_preprocessing_layer_test(
+            cls=MistralCausalLMPreprocessor,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=(
+                {
+                    "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
+                },
+                [[3, 8, 4, 6, 0, 0, 0, 0]],  # Pass through labels.
+                [[1, 1, 1, 1, 0, 0, 0, 0]],  # Pass through sample_weights.
+            ),
+        )
+
+    def test_no_start_end_token(self):
+        input_data = ["the quick brown fox"] * 4
+
+        preprocessor = MistralCausalLMPreprocessor(
+            **self.init_kwargs,
+            add_start_token=False,
+            add_end_token=False,
+        )
+        x, y, sw = preprocessor(input_data)
+        self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4)
+
+    def test_generate_preprocess(self):
+        input_data = "the quick brown fox"
+        preprocessor = MistralCausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_preprocess(input_data)
+        self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0])
+        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])
+
+    def test_generate_postprocess(self):
+        input_data = {
+            "token_ids": [1, 3, 8, 4, 6, 0, 0, 0],
+            "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
+        }
+        preprocessor = MistralCausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_postprocess(input_data)
+        self.assertAllEqual(x, "the quick brown fox")
diff --git a/keras_nlp/models/mistral/mistral_causal_lm_test.py b/keras_nlp/models/mistral/mistral_causal_lm_test.py
new file mode 100644
index 0000000000..3f9d7fab36
--- /dev/null
+++ b/keras_nlp/models/mistral/mistral_causal_lm_test.py
@@ -0,0 +1,130 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from unittest.mock import patch
+
+import pytest
+
+from keras_nlp.backend import ops
+from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
+from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM
+from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
+    MistralCausalLMPreprocessor,
+)
+from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+class MistralCausalLMTest(TestCase):
+    def setUp(self):
+        self.preprocessor = MistralCausalLMPreprocessor(
+            MistralTokenizer(
+                # Generated using create_mistral_test_proto.py
+                proto=os.path.join(
+                    self.get_test_data_dir(), "mistral_test_vocab.spm"
+                )
+            ),
+            sequence_length=8,
+        )
+        self.backbone = MistralBackbone(
+            vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+            num_layers=2,
+            num_query_heads=4,
+            num_key_value_heads=2,
+            hidden_dim=8,
+            intermediate_dim=16,
+        )
+        self.init_kwargs = {
+            "preprocessor": self.preprocessor,
+            "backbone": self.backbone,
+        }
+        self.train_data = (["the quick brown fox", "the earth is round"],)
+        self.input_data = self.preprocessor(*self.train_data)[0]
+
+    def test_causal_lm_basics(self):
+        self.run_task_test(
+            cls=MistralCausalLM,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 8, 10),
+        )
+
+    def test_generate(self):
+        causal_lm = MistralCausalLM(**self.init_kwargs)
+        # String input.
+        prompt = "the quick brown fox"
+        output = causal_lm.generate(prompt)
+        self.assertTrue(prompt in output)
+        # Int tensor input.
+        prompt_ids = self.preprocessor.generate_preprocess([prompt])
+        causal_lm.preprocessor = None
+        outputs = causal_lm.generate(prompt_ids)
+        # Assert prompt is in output in token id space.
+        self.assertAllEqual(
+            outputs["token_ids"][:, :5],
+            prompt_ids["token_ids"][:, :5],
+        )
+        self.assertAllEqual(
+            outputs["padding_mask"][:, :5],
+            prompt_ids["padding_mask"][:, :5],
+        )
+
+    def test_early_stopping(self):
+        causal_lm = MistralCausalLM(**self.init_kwargs)
+        call_with_cache = causal_lm.call_with_cache
+
+        def wrapper(*args, **kwargs):
+            """Modify output logits to always favor end_token_id"""
+            logits, hidden_states, cache = call_with_cache(*args, **kwargs)
+            index = self.preprocessor.tokenizer.end_token_id
+            update = ops.ones_like(logits)[:, :, index] * 1.0e9
+            update = ops.expand_dims(update, axis=-1)
+            logits = ops.slice_update(logits, (0, 0, index), update)
+            return logits, hidden_states, cache
+
+        with patch.object(causal_lm, "call_with_cache", wraps=wrapper):
+            prompt = ["the quick brown fox", "the earth"]
+            output = causal_lm.generate(prompt)
+            # We should immediately abort and output the prompt.
+            self.assertEqual(prompt, output)
+
+    def test_generate_compilation(self):
+        causal_lm = MistralCausalLM(**self.init_kwargs)
+        # Assert we do not recompile with successive calls.
+        causal_lm.generate("the quick brown fox")
+        first_fn = causal_lm.generate_function
+        causal_lm.generate("the quick brown fox")
+        second_fn = causal_lm.generate_function
+        self.assertEqual(first_fn, second_fn)
+        # Assert we do recompile after compile is called.
+        causal_lm.compile(sampler="greedy")
+        self.assertIsNone(causal_lm.generate_function)
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=MistralCausalLM,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in MistralCausalLM.presets:
+            self.run_preset_test(
+                cls=MistralCausalLM,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/keras_nlp/models/mistral/mistral_transformer_decoder.py b/keras_nlp/models/mistral/mistral_transformer_decoder.py
index 9b6f7fdbf8..65b8a3708d 100644
--- a/keras_nlp/models/mistral/mistral_transformer_decoder.py
+++ b/keras_nlp/models/mistral/mistral_transformer_decoder.py
@@ -36,7 +36,7 @@ def __init__(
         num_key_value_heads,
         rope_max_wavelength=10000,
         rope_scaling_factor=1.0,
-        activation="relu",
+        activation="silu",
         layer_norm_epsilon=1e-5,
         kernel_initializer="glorot_uniform",
         sliding_window=512,
@@ -145,6 +145,8 @@ def call(
             decoder_sequence=decoder_sequence,
             decoder_padding_mask=decoder_padding_mask,
             decoder_attention_mask=decoder_attention_mask,
+            self_attention_cache=self_attention_cache,
+            self_attention_cache_update_index=self_attention_cache_update_index,
         )
 
         residual = decoder_sequence
@@ -184,23 +186,36 @@ def _compute_self_attention_mask(
         decoder_sequence,
         decoder_padding_mask,
         decoder_attention_mask,
+        self_attention_cache,
+        self_attention_cache_update_index,
     ):
         decoder_mask = merge_padding_and_attention_mask(
             decoder_sequence, decoder_padding_mask, decoder_attention_mask
         )
         batch_size = ops.shape(decoder_sequence)[0]
         input_length = output_length = ops.shape(decoder_sequence)[1]
+        # We need to handle a rectangular causal mask when doing cached
+        # decoding. For generative inference, `decoder_sequence` will
+        # generally be length 1, and `cache` will be the full generation length.
+        if self_attention_cache is not None:
+            input_length = ops.shape(self_attention_cache)[2]
+
+        cache_update_index = (
+            0
+            if self_attention_cache_update_index is None
+            else self_attention_cache_update_index
+        )
         # Mistral uses a banded attention mask
         causal_mask_lower = compute_causal_mask(
-            batch_size, input_length, output_length, 0
+            batch_size, input_length, output_length, cache_update_index
        )
         # Below is a workaround for `ops.triu` for Keras 2.
         # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
         # causal_mask = ops.triu(causal_mask_lower, k=-self.sliding_window)
-        i = ops.arange(output_length)[:, None]
+        i = ops.arange(output_length)[:, None] + cache_update_index
         j = ops.arange(input_length)[None, :]
-        causal_mask_upper = ops.cast(i <= j + self.sliding_window, "int32")
+        causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32")
         causal_mask = ops.minimum(causal_mask_lower, causal_mask_upper)
 
         return (
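
To make the new surface area concrete, here is a short usage sketch of the classes this patch exports, mirroring the wiring in `mistral_causal_lm_test.py`. The `proto` path and the tiny backbone dimensions are illustrative placeholders (the tests use a generated `mistral_test_vocab.spm` and deliberately small sizes), not values shipped with the library.

```python
import keras_nlp

# Assemble the pieces added in this patch:
# tokenizer -> preprocessor -> backbone -> causal LM task.
# The SentencePiece proto path below is hypothetical.
tokenizer = keras_nlp.models.MistralTokenizer(proto="path/to/vocab.spm")
preprocessor = keras_nlp.models.MistralCausalLMPreprocessor(
    tokenizer=tokenizer,
    sequence_length=128,
)
backbone = keras_nlp.models.MistralBackbone(
    vocabulary_size=tokenizer.vocabulary_size(),
    num_layers=2,            # toy sizes for illustration only
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=8,
    intermediate_dim=16,
)
causal_lm = keras_nlp.models.MistralCausalLM(
    backbone=backbone,
    preprocessor=preprocessor,
)

# `generate()` uses `"top_k"` sampling by default; recompile to change the
# sampler, as the test suite does with `sampler="greedy"`.
causal_lm.compile(sampler="greedy")
print(causal_lm.generate("the quick brown fox"))
```

Because the preprocessor is attached, `generate()` accepts raw strings and calls `generate_preprocess()`/`generate_postprocess()` implicitly; with `preprocessor=None`, it expects and returns `token_ids`/`padding_mask` dictionaries instead.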