
Commit f1ab62a

Fix mistral and electra tokenizer to match kaggle changes (#1387)
We are changing all tokenizers to store vocabularies via assets (and not in the config). This requires some changes to the tokenizers so that file state can be set after object creation.
1 parent 401e569 commit f1ab62a
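
For context, the asset-based flow this enables looks roughly like the sketch below. This is a hedged example, not code from this commit: it assumes the ElectraTokenizer as changed here accepts vocabulary=None, and "vocab.txt" is a hypothetical asset file stored alongside (rather than inside) the config.

    from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer

    # Config-only construction: no vocabulary yet, so special token ids are None.
    tokenizer = ElectraTokenizer(vocabulary=None)

    # Later, loading machinery restores the vocabulary asset and sets file state.
    with open("vocab.txt") as f:  # hypothetical asset path
        vocab = [line.rstrip("\n") for line in f]
    tokenizer.set_vocabulary(vocab)

After set_vocabulary runs, tokenizer.cls_token_id and the other special token ids point into the restored vocabulary again.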

File tree

keras_nlp/models/electra/electra_tokenizer.py
keras_nlp/models/mistral/mistral_tokenizer.py

2 files changed: +43 -30 lines


keras_nlp/models/electra/electra_tokenizer.py

Lines changed: 26 additions & 17 deletions
@@ -58,22 +58,31 @@ class ElectraTokenizer(WordPieceTokenizer):
     """
 
     def __init__(self, vocabulary, lowercase=False, **kwargs):
+        self.cls_token = "[CLS]"
+        self.sep_token = "[SEP]"
+        self.pad_token = "[PAD]"
+        self.mask_token = "[MASK]"
         super().__init__(vocabulary=vocabulary, lowercase=lowercase, **kwargs)
 
-        # Check for special tokens
-        cls_token = "[CLS]"
-        sep_token = "[SEP]"
-        pad_token = "[PAD]"
-        mask_token = "[MASK]"
-
-        for token in [cls_token, pad_token, sep_token, mask_token]:
-            if token not in self.get_vocabulary():
-                raise ValueError(
-                    f"Cannot find token `'{token}'` in the provided "
-                    f"`vocabulary`. Please provide `'{token}'` in your "
-                    "`vocabulary` or use a pretrained `vocabulary` name."
-                )
-        self.cls_token_id = self.token_to_id(cls_token)
-        self.sep_token_id = self.token_to_id(sep_token)
-        self.pad_token_id = self.token_to_id(pad_token)
-        self.mask_token_id = self.token_to_id(mask_token)
+    def set_vocabulary(self, vocabulary):
+        super().set_vocabulary(vocabulary)
+
+        if vocabulary is not None:
+            # Check for necessary special tokens.
+            for token in [self.cls_token, self.pad_token, self.sep_token]:
+                if token not in self.vocabulary:
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+
+            self.cls_token_id = self.token_to_id(self.cls_token)
+            self.sep_token_id = self.token_to_id(self.sep_token)
+            self.pad_token_id = self.token_to_id(self.pad_token)
+            self.mask_token_id = self.token_to_id(self.mask_token)
+        else:
+            self.cls_token_id = None
+            self.sep_token_id = None
+            self.pad_token_id = None
+            self.mask_token_id = None
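
A quick sketch of the new behavior, assuming a toy word-piece vocabulary invented for illustration (it just needs to contain the required special tokens):

    # Hypothetical toy vocabulary for illustration only.
    vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "the", "quick"]
    tokenizer = ElectraTokenizer(vocabulary=vocab)
    print(tokenizer.cls_token_id)  # 2, the index of "[CLS]" in vocab

    tokenizer.set_vocabulary(None)
    print(tokenizer.cls_token_id)  # None until a vocabulary is set again

Note the design choice visible in the diff: the special token strings are now assigned before super().__init__(), since the base constructor can call set_vocabulary, which reads them.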

keras_nlp/models/mistral/mistral_tokenizer.py

Lines changed: 17 additions & 13 deletions
@@ -58,18 +58,22 @@ class MistralTokenizer(SentencePieceTokenizer):
     """
 
    def __init__(self, proto, **kwargs):
+        self.start_token = "<s>"
+        self.end_token = "</s>"
         super().__init__(proto=proto, **kwargs)
 
-        # Check for necessary special tokens.
-        start_token = "<s>"
-        end_token = "</s>"
-        for token in [start_token, end_token]:
-            if token not in self.get_vocabulary():
-                raise ValueError(
-                    f"Cannot find token `'{token}'` in the provided "
-                    f"`vocabulary`. Please provide `'{token}'` in your "
-                    "`vocabulary` or use a pretrained `vocabulary` name."
-                )
-
-        self.start_token_id = self.token_to_id(start_token)
-        self.end_token_id = self.token_to_id(end_token)
+    def set_proto(self, proto):
+        super().set_proto(proto)
+        if proto is not None:
+            for token in [self.start_token, self.end_token]:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+            self.start_token_id = self.token_to_id(self.start_token)
+            self.end_token_id = self.token_to_id(self.end_token)
+        else:
+            self.start_token_id = None
+            self.end_token_id = None
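
Likewise for Mistral, where the SentencePiece proto can now arrive after construction. A minimal sketch, assuming the constructor accepts proto=None after this change and that "mistral.spm" is a hypothetical path to a serialized SentencePiece model whose vocabulary contains "<s>" and "</s>":

    tokenizer = MistralTokenizer(proto=None)
    print(tokenizer.start_token_id)  # None while no proto is set

    tokenizer.set_proto("mistral.spm")  # hypothetical model file
    print(tokenizer.start_token_id)    # id of "<s>" in the loaded model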
