diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
index c0b91abb3c..5e7873a455 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
@@ -156,9 +156,29 @@ def tokenize(self, inputs):
         # Shift the tokens IDs right by one.
         return tf.add(tokens, 1)
 
-    def detokenize(self, ids):
-        ids = tf.ragged.boolean_mask(ids, tf.not_equal(ids, self.mask_token_id))
-        return super().detokenize(ids)
+    def detokenize(self, inputs):
+        if inputs.dtype == tf.string:
+            return super().detokenize(inputs)
+
+        tokens = tf.ragged.boolean_mask(
+            inputs, tf.not_equal(inputs, self.mask_token_id)
+        )
+
+        # Shift the tokens IDs left by one.
+        tokens = tf.subtract(tokens, 1)
+
+        # Correct `unk_token_id`, `end_token_id`, `start_token_id`, respectively.
+        # Note: The `pad_token_id` is taken as 0 (`unk_token_id`) since the
+        # proto does not contain `pad_token_id`. This mapping of the pad token
+        # is done automatically by the above subtraction.
+        tokens = tf.where(tf.equal(tokens, self.unk_token_id - 1), 0, tokens)
+        tokens = tf.where(tf.equal(tokens, self.end_token_id - 1), 2, tokens)
+        tokens = tf.where(tf.equal(tokens, self.start_token_id - 1), 1, tokens)
+
+        # Note: Even though we map `"<s>"` and `"</s>"` to the correct IDs,
+        # the `detokenize` method will return empty strings for these tokens.
+        # This is a vagary of the `sentencepiece` library.
+        return super().detokenize(tokens)
 
     @classproperty
     def presets(cls):
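
A minimal standalone sketch of the un-shift logic introduced above, outside the tokenizer class. The concrete ID values (mask=250001, unk=3, start=0, end=2) and the sample ID sequence are illustrative assumptions for this sketch, not values read from a real SentencePiece proto or from the library's presets.

import tensorflow as tf

# Assumed special-token IDs, chosen only for illustration.
mask_token_id = 250001
unk_token_id = 3
start_token_id = 0
end_token_id = 2

# Hypothetical output of `tokenize`: <s>, two word pieces, </s>, <mask>.
ids = tf.constant([0, 581, 63773, 2, 250001])

# Drop mask tokens, then undo the "+1" shift applied in `tokenize`.
tokens = tf.boolean_mask(ids, tf.not_equal(ids, mask_token_id))
tokens = tf.subtract(tokens, 1)

# Remap the special tokens back onto the IDs the underlying proto uses
# (<unk> -> 0, </s> -> 2, <s> -> 1), in the same order as in the diff.
tokens = tf.where(tf.equal(tokens, unk_token_id - 1), 0, tokens)
tokens = tf.where(tf.equal(tokens, end_token_id - 1), 2, tokens)
tokens = tf.where(tf.equal(tokens, start_token_id - 1), 1, tokens)

print(tokens)  # tf.Tensor([    1   580 63772     2], shape=(4,), dtype=int32)

The remap order matters: the unknown token is corrected first, then the end token, then the start token, so that an ID produced by one remap is not picked up again by a later comparison.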