diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
index e238b668e9..8c15de8574 100644
--- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
@@ -266,10 +266,8 @@ def generate_postprocess(
             x["decoder_token_ids"],
             x["decoder_padding_mask"],
         )
-        if not isinstance(decoder_token_ids, tf.Tensor):
-            decoder_token_ids = ops.convert_to_numpy(decoder_token_ids)
-        if not isinstance(decoder_padding_mask, tf.Tensor):
-            decoder_padding_mask = ops.convert_to_numpy(decoder_padding_mask)
+        decoder_token_ids = ops.convert_to_numpy(decoder_token_ids)
+        decoder_padding_mask = ops.convert_to_numpy(decoder_padding_mask)
         # Strip any special tokens during detokenization, i.e., the start and
         # end markers. In the future, we could make this configurable.
         decoder_padding_mask = (
diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
index 37493bb91d..f67dab70a0 100644
--- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
@@ -14,7 +14,6 @@
 import pytest
 
-from keras_nlp.backend import ops
 from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import (
     BartSeq2SeqLMPreprocessor,
 )
@@ -82,8 +81,8 @@ def test_generate_preprocess(self):
     def test_generate_postprocess(self):
         preprocessor = BartSeq2SeqLMPreprocessor(**self.init_kwargs)
         input_data = {
-            "decoder_token_ids": ops.array([0, 4, 5, 6, 2], dtype="int32"),
-            "decoder_padding_mask": ops.array([1, 1, 1, 1, 1], dtype="bool"),
+            "decoder_token_ids": [0, 4, 5, 6, 2],
+            "decoder_padding_mask": [1, 1, 1, 1, 1],
         }
         output = preprocessor.generate_postprocess(input_data)
         self.assertAllEqual(output, " airplane at")
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
index b501ad3fe0..41ea591df8 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -163,10 +163,8 @@ def generate_postprocess(
         back to a string.
         """
         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
-        if not isinstance(token_ids, tf.Tensor):
-            token_ids = ops.convert_to_numpy(token_ids)
-        if not isinstance(padding_mask, tf.Tensor):
-            padding_mask = ops.convert_to_numpy(padding_mask)
+        token_ids = ops.convert_to_numpy(token_ids)
+        padding_mask = ops.convert_to_numpy(padding_mask)
         # Strip any special tokens during detokenization (e.g. the start and
         # end markers). In the future we could make this configurable.
         padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
index b0cdd2e3ee..400273b792 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import pytest
-import tensorflow as tf
 
 from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
     GPT2CausalLMPreprocessor,
 )
@@ -78,8 +77,8 @@ def test_generate_preprocess(self):
 
     def test_generate_postprocess(self):
         input_data = {
-            "token_ids": tf.constant([6, 1, 3, 4, 2, 5, 0, 0]),
-            "padding_mask": tf.cast([1, 1, 1, 1, 1, 1, 0, 0], dtype="bool"),
+            "token_ids": [6, 1, 3, 4, 2, 5, 0, 0],
+            "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
         }
         preprocessor = GPT2CausalLMPreprocessor(**self.init_kwargs)
         x = preprocessor.generate_postprocess(input_data)
diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py
index 26f01a32d1..9cc8c7f495 100644
--- a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py
+++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py
@@ -164,10 +164,8 @@ def generate_postprocess(
         back to a string.
         """
         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
-        if not isinstance(token_ids, tf.Tensor):
-            token_ids = ops.convert_to_numpy(token_ids)
-        if not isinstance(padding_mask, tf.Tensor):
-            padding_mask = ops.convert_to_numpy(padding_mask)
+        token_ids = ops.convert_to_numpy(token_ids)
+        padding_mask = ops.convert_to_numpy(padding_mask)
         # Strip any special tokens during detokenization (e.g. the start and
         # end markers). In the future we could make this configurable.
         padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
index 2f225612d4..9ba6851d4b 100644
--- a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
+++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import pytest
-import tensorflow as tf
 
 from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
     OPTCausalLMPreprocessor,
 )
@@ -77,8 +76,8 @@ def test_generate_preprocess(self):
 
     def test_generate_postprocess(self):
         input_data = {
-            "token_ids": tf.constant([1, 2, 4, 5, 3, 6, 0, 0]),
-            "padding_mask": tf.cast([1, 1, 1, 1, 1, 1, 0, 0], dtype="bool"),
+            "token_ids": [1, 2, 4, 5, 3, 6, 0, 0],
+            "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
         }
         preprocessor = OPTCausalLMPreprocessor(**self.init_kwargs)
         x = preprocessor.generate_postprocess(input_data)
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index f92d9e6a77..133c9565b0 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -561,7 +561,7 @@ def process_unseen_tokens():
 
     def detokenize(self, inputs):
         inputs, unbatched, _ = convert_to_ragged_batch(inputs)
-
+        inputs = tf.cast(inputs, self.dtype)
         unicode_text = tf.strings.reduce_join(
             self.id_to_token_map.lookup(inputs), axis=-1
         )
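
Note (not part of the patch): with the isinstance guards removed, generate_postprocess routes every input through ops.convert_to_numpy, which accepts plain Python lists as well as backend tensors, and the tf.cast in detokenize aligns the incoming ids with the lookup table's key dtype. A minimal usage sketch follows; the preset name and token ids are illustrative assumptions, not taken from this diff, and only the call pattern mirrors the updated tests.

    # Hypothetical sketch: "gpt2_base_en" and the token ids below are
    # illustrative; they are not part of the patch above.
    from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
        GPT2CausalLMPreprocessor,
    )

    preprocessor = GPT2CausalLMPreprocessor.from_preset("gpt2_base_en")

    # Plain Python lists now work: no backend tensor is required, since
    # ops.convert_to_numpy handles lists, NumPy arrays, and tensors alike.
    text = preprocessor.generate_postprocess(
        {
            "token_ids": [50256, 464, 2068, 7586, 21831, 50256, 0, 0],
            "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
        }
    )
    print(text)  # Decoded string; end-of-text ids and padding are stripped.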