Merged
Changes from 1 commit
Commits
27 commits
62eaf03
Start landing code for Kaggle integration (#1320)
mattdangerw Nov 20, 2023
21fb04c
Switch byte pair tokenizer to save_assets/load_assets (#1322)
mattdangerw Nov 21, 2023
0e3c674
Convert SentencePieceTokenizer and associated models to new assets pa…
nkovela1 Nov 21, 2023
3619a6a
Add tests for Presets workflow, Add Metadata (#1326)
nkovela1 Nov 23, 2023
38806fd
Automatically add the keras framework to kaggle handles (#1331)
mattdangerw Nov 29, 2023
e0d34dc
Fix a failing byte pair tokenizer test (#1336)
mattdangerw Nov 30, 2023
0820d62
Use set comparison for assets (#1335)
mattdangerw Nov 30, 2023
c4b0c3c
Fix whisper tokenizer saving (#1334)
mattdangerw Nov 30, 2023
e3f8d06
Remove special case Bart from_preset (#1333)
mattdangerw Nov 30, 2023
dbb6487
Fix t5 tokenizer presets (#1339)
mattdangerw Nov 30, 2023
6130253
Script to convert presets (#1340)
mattdangerw Nov 30, 2023
814959b
Switch all presets to the new Kaggle format (#1338)
mattdangerw Dec 1, 2023
2aced24
Let kagglehub select latest version (#1342)
mattdangerw Dec 4, 2023
245b7e9
Use the proper title for example (#1346)
Philmod Dec 5, 2023
6ad8a30
Update conversion script (#1347)
mattdangerw Dec 6, 2023
7cc4323
Improve preset error messages (#1349)
mattdangerw Dec 7, 2023
9cc8110
Use subclass checking check_preset_class (#1344)
mattdangerw Dec 7, 2023
4606f32
Add a hacky fix for TF 2.13 and 2.14 weights.h5 loading (#1353)
mattdangerw Dec 7, 2023
9cb5838
Another fix for saving on Keras 2 (#1354)
mattdangerw Dec 7, 2023
039ff45
Switch our presets to their final Kaggle location (#1345)
mattdangerw Dec 7, 2023
9cc3f84
Fix rebase issue in bytepair tokenizer (#1366)
nkovela1 Dec 12, 2023
6f7f9a0
Change encoding to utf-8 to fix Kaggle branch test failure for PyTorc…
sampathweb Dec 13, 2023
ddfca77
Fix GPU test issue with Keras 2 (#1368)
nkovela1 Dec 14, 2023
0e43f09
Add in-place modification of file keys for backwards compatibility (#…
nkovela1 Dec 15, 2023
4d84eb1
Add file renaming logic for modification (#1370)
nkovela1 Dec 16, 2023
29a0ae5
Fix task pre-processor in tasks (#1373)
sampathweb Dec 20, 2023
401e569
Backwards compatible fix for functional model saving (#1378)
mattdangerw Jan 4, 2024
Switch byte pair tokenizer to save_assets/load_assets (#1322)
As part of this work, we need to also switch all downstream
preprocessing layers to create packers on build (instead of on call).
mattdangerw committed Jan 4, 2024
commit 21fb04ce753f0e05b1fb424beb7f3b19a404a5b3
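
The pattern behind most of the file changes below, condensed into a short sketch. This is not the library's BartPreprocessor: the class name MyPreprocessor, the single sequence_length argument, and the plain keras.layers.Layer base class are simplifications for illustration; only StartEndPacker is the real keras_nlp.layers API that the diff itself uses.

import keras
from keras_nlp.layers import StartEndPacker


class MyPreprocessor(keras.layers.Layer):
    """Sketch of a preprocessor that defers packer creation to build()."""

    def __init__(self, tokenizer, sequence_length=128, **kwargs):
        super().__init__(**kwargs)
        # Only store configuration here; do not read tokenizer assets yet.
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.packer = None

    def build(self, input_shape):
        # By the time build() runs, tokenizer assets (vocabulary, merges)
        # have been restored, so the special token ids below are valid.
        self.packer = StartEndPacker(
            start_value=self.tokenizer.start_token_id,
            end_value=self.tokenizer.end_token_id,
            pad_value=self.tokenizer.pad_token_id,
            sequence_length=self.sequence_length,
            return_padding_mask=True,
        )
        self.built = True

    def call(self, x):
        token_ids, padding_mask = self.packer(self.tokenizer(x))
        return {"token_ids": token_ids, "padding_mask": padding_mask}

    def get_config(self):
        # Report the stored length, not packer.sequence_length: the packer
        # may not exist yet when the layer is serialized.
        config = super().get_config()
        config.update({"sequence_length": self.sequence_length})
        return config
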
39 changes: 24 additions & 15 deletions keras_nlp/models/bart/bart_preprocessor.py
@@ -140,15 +140,23 @@ def __init__(
):
super().__init__(**kwargs)
self.tokenizer = tokenizer
self.encoder_sequence_length = encoder_sequence_length
self.decoder_sequence_length = decoder_sequence_length
self.encoder_packer = None
self.decoder_packer = None

def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
# assets have loaded when restoring a saved model.

# TODO: Use `MultiSegmentPacker` instead of `StartEndPacker` once we
# want to move to multi-segment packing and have improved
# `MultiSegmentPacker`'s performance.
self.encoder_packer = StartEndPacker(
start_value=tokenizer.start_token_id,
end_value=tokenizer.end_token_id,
pad_value=tokenizer.pad_token_id,
sequence_length=encoder_sequence_length,
start_value=self.tokenizer.start_token_id,
end_value=self.tokenizer.end_token_id,
pad_value=self.tokenizer.pad_token_id,
sequence_length=self.encoder_sequence_length,
return_padding_mask=True,
)

@@ -161,19 +169,10 @@ def __init__(
],
end_value=self.tokenizer.end_token_id,
pad_value=self.tokenizer.pad_token_id,
sequence_length=decoder_sequence_length,
sequence_length=self.decoder_sequence_length,
return_padding_mask=True,
)

def get_config(self):
config = super().get_config()
config.update(
{
"encoder_sequence_length": self.encoder_packer.sequence_length,
"decoder_sequence_length": self.decoder_packer.sequence_length,
}
)
return config
self.built = True

def call(self, x, y=None, sample_weight=None):
if not (
@@ -217,6 +216,16 @@ def call(self, x, y=None, sample_weight=None):

return pack_x_y_sample_weight(x, y, sample_weight)

def get_config(self):
config = super().get_config()
config.update(
{
"encoder_sequence_length": self.encoder_sequence_length,
"decoder_sequence_length": self.decoder_sequence_length,
}
)
return config

@classproperty
def tokenizer_cls(cls):
return BartTokenizer
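
A hypothetical illustration, building on the MyPreprocessor sketch near the top, of why get_config() in the diff above now reads self.encoder_sequence_length and self.decoder_sequence_length instead of the packers' attributes: before build() runs the packers are still None. The BART layer stores two lengths; the sketch's single sequence_length plays the same role. The stub tokenizer below is an assumption, not a KerasNLP class.

class _StubTokenizer:
    """Stand-in exposing only the ids the packer needs; not a real tokenizer."""

    start_token_id, end_token_id, pad_token_id = 1, 2, 0


pre = MyPreprocessor(tokenizer=_StubTokenizer(), sequence_length=64)
assert pre.packer is None                         # nothing has been built yet
assert pre.get_config()["sequence_length"] == 64  # serialization still works
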
45 changes: 20 additions & 25 deletions keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
@@ -46,16 +46,6 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor):
tokenizer: A `keras_nlp.models.BartTokenizer` instance.
encoder_sequence_length: The length of the packed encoder inputs.
decoder_sequence_length: The length of the packed decoder inputs.
truncate: string. The algorithm to truncate a list of batched segments
to fit within `sequence_length`. The value can be either
`round_robin` or `waterfall`:
- `"round_robin"`: Available space is assigned one token at a
time in a round-robin fashion to the inputs that still need
some, until the limit is reached.
- `"waterfall"`: The allocation of the budget is done using a
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.

Call arguments:
x: A dictionary with `encoder_text` and `decoder_text` as its keys.
@@ -139,7 +129,6 @@ def __init__(
tokenizer,
encoder_sequence_length,
decoder_sequence_length,
truncate="round_robin",
**kwargs
):
# Since we truncate the last token from `decoder_token_ids`, we need to
@@ -156,16 +145,6 @@ def __init__(
self._encoder_sequence_length = encoder_sequence_length
self._decoder_sequence_length = decoder_sequence_length

def get_config(self):
config = super().get_config()
config.update(
{
"encoder_sequence_length": self._encoder_sequence_length,
"decoder_sequence_length": self._decoder_sequence_length,
}
)
return config

def call(self, x, y=None, sample_weight=None):
if y is not None or sample_weight is not None:
logging.warning(
@@ -191,10 +170,6 @@ def call(self, x, y=None, sample_weight=None):
sample_weight = decoder_padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

def generate_preprocess(
self,
x,
@@ -212,6 +187,9 @@ def generate_preprocess(
the decoder sequence (as generation is expected to continue at the end
of the inputted decoder prompt).
"""
if not self.built:
self.build(None)

# If `sequence_length` is not provided, we use the default value.
if sequence_length is None:
sequence_length = self._decoder_sequence_length
@@ -262,6 +240,9 @@ def generate_postprocess(
padding and start/end tokens, and then converting the integer sequence
back to a string.
"""
if not self.built:
self.build(None)

decoder_token_ids, decoder_padding_mask = (
x["decoder_token_ids"],
x["decoder_padding_mask"],
@@ -279,3 +260,17 @@
decoder_token_ids, decoder_padding_mask
)
return self.tokenizer.detokenize(decoder_token_ids)

def get_config(self):
config = super().get_config()
config.update(
{
"encoder_sequence_length": self._encoder_sequence_length,
"decoder_sequence_length": self._decoder_sequence_length,
}
)
return config

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
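
The `if not self.built: self.build(None)` guard added to generate_preprocess and generate_postprocess above, restated as a sketch on the hypothetical MyPreprocessor class. The method body is a simplification, not the BART implementation, and it assumes StartEndPacker's call signature accepts sequence_length and add_end_value keyword arguments.

class MyGenerativePreprocessor(MyPreprocessor):
    def generate_preprocess(self, x, sequence_length=None):
        # Generation entry points are not routed through __call__(), so they
        # must make sure build() has created the packer before using it.
        if not self.built:
            self.build(None)
        token_ids, padding_mask = self.packer(
            self.tokenizer(x),
            sequence_length=sequence_length or self.sequence_length,
            add_end_value=False,  # generation continues after the prompt
        )
        return {"token_ids": token_ids, "padding_mask": padding_mask}
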
@@ -21,7 +21,7 @@
from keras_nlp.tests.test_case import TestCase


class BartPreprocessorTest(TestCase):
class BartSeq2SeqLMPreprocessorTest(TestCase):
def setUp(self):
self.vocab = ["<s>", "<pad>", "</s>", "air", "Ġair", "plane", "Ġat"]
self.vocab += ["port", "<mask>"]
49 changes: 30 additions & 19 deletions keras_nlp/models/bart/bart_tokenizer.py
@@ -78,34 +78,45 @@ class BartTokenizer(BytePairTokenizer):

def __init__(
self,
vocabulary,
merges,
vocabulary=None,
merges=None,
**kwargs,
):
# Special tokens.
start_token = "<s>"
pad_token = "<pad>"
end_token = "</s>"
self.start_token = "<s>"
self.pad_token = "<pad>"
self.end_token = "</s>"

super().__init__(
vocabulary=vocabulary,
merges=merges,
unsplittable_tokens=[start_token, pad_token, end_token],
unsplittable_tokens=[
self.start_token,
self.pad_token,
self.end_token,
],
**kwargs,
)

# Check whether special tokens are present in the vocabulary.
for token in [start_token, pad_token, end_token]:
if token not in self.get_vocabulary():
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self.start_token_id = self.token_to_id(start_token)
self.pad_token_id = self.token_to_id(pad_token)
self.end_token_id = self.token_to_id(end_token)
def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)

if vocabulary is not None:
# Check for necessary special tokens.
for token in [self.start_token, self.pad_token, self.end_token]:
if token not in self.vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self.start_token_id = self.token_to_id(self.start_token)
self.pad_token_id = self.token_to_id(self.pad_token)
self.end_token_id = self.token_to_id(self.end_token)
else:
self.start_token_id = None
self.pad_token_id = None
self.end_token_id = None

@classproperty
def presets(cls):
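
For reference, the tokenizer-side half of this commit consolidated into one sketch: special tokens become attributes in __init__, and their ids are resolved in set_vocabulary_and_merges(), which may run with None at construction and again once assets are loaded. MyBPETokenizer and its two special tokens are illustrative assumptions; BytePairTokenizer, set_vocabulary_and_merges, and unsplittable_tokens follow the diff above. The expectation is that load_assets() (named in the commit title) eventually calls this same setter with the restored vocabulary and merges.

from keras_nlp.tokenizers import BytePairTokenizer


class MyBPETokenizer(BytePairTokenizer):
    def __init__(self, vocabulary=None, merges=None, **kwargs):
        # Record special tokens before the base class resolves the vocabulary.
        self.start_token = "<s>"
        self.end_token = "</s>"
        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            unsplittable_tokens=[self.start_token, self.end_token],
            **kwargs,
        )

    def set_vocabulary_and_merges(self, vocabulary, merges):
        super().set_vocabulary_and_merges(vocabulary, merges)
        if vocabulary is not None:
            for token in (self.start_token, self.end_token):
                if token not in self.vocabulary:
                    raise ValueError(
                        f"Cannot find token '{token}' in the provided vocabulary."
                    )
            self.start_token_id = self.token_to_id(self.start_token)
            self.end_token_id = self.token_to_id(self.end_token)
        else:
            # No assets yet (e.g. a preset still to be loaded); ids stay None
            # until this method is called again with real data.
            self.start_token_id = None
            self.end_token_id = None
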
20 changes: 10 additions & 10 deletions keras_nlp/models/bert/bert_preprocessor.py
@@ -143,16 +143,6 @@ def __init__(
self.truncate = truncate
self.packer = None

def get_config(self):
config = super().get_config()
config.update(
{
"sequence_length": self.sequence_length,
"truncate": self.truncate,
}
)
return config

def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
# assets have loaded when restoring a saved model.
@@ -176,6 +166,16 @@ def call(self, x, y=None, sample_weight=None):
}
return pack_x_y_sample_weight(x, y, sample_weight)

def get_config(self):
config = super().get_config()
config.update(
{
"sequence_length": self.sequence_length,
"truncate": self.truncate,
}
)
return config

@classproperty
def tokenizer_cls(cls):
return BertTokenizer
18 changes: 9 additions & 9 deletions keras_nlp/models/bert/bert_tokenizer.py
@@ -78,6 +78,10 @@ def __init__(
lowercase=False,
**kwargs,
):
self.cls_token = "[CLS]"
self.sep_token = "[SEP]"
self.pad_token = "[PAD]"
self.mask_token = "[MASK]"
super().__init__(
vocabulary=vocabulary,
lowercase=lowercase,
@@ -89,22 +93,18 @@ def set_vocabulary(self, vocabulary):

if vocabulary is not None:
# Check for necessary special tokens.
cls_token = "[CLS]"
sep_token = "[SEP]"
pad_token = "[PAD]"
mask_token = "[MASK]"
for token in [cls_token, pad_token, sep_token]:
for token in [self.cls_token, self.pad_token, self.sep_token]:
if token not in self.vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self.cls_token_id = self.token_to_id(cls_token)
self.sep_token_id = self.token_to_id(sep_token)
self.pad_token_id = self.token_to_id(pad_token)
self.mask_token_id = self.token_to_id(mask_token)
self.cls_token_id = self.token_to_id(self.cls_token)
self.sep_token_id = self.token_to_id(self.sep_token)
self.pad_token_id = self.token_to_id(self.pad_token)
self.mask_token_id = self.token_to_id(self.mask_token)
else:
self.cls_token_id = None
self.sep_token_id = None
20 changes: 10 additions & 10 deletions keras_nlp/models/distil_bert/distil_bert_tokenizer.py
@@ -76,6 +76,10 @@ def __init__(
lowercase=False,
**kwargs,
):
self.cls_token = "[CLS]"
self.sep_token = "[SEP]"
self.pad_token = "[PAD]"
self.mask_token = "[MASK]"
super().__init__(
vocabulary=vocabulary,
lowercase=lowercase,
@@ -87,22 +91,18 @@ def set_vocabulary(self, vocabulary):

if vocabulary is not None:
# Check for necessary special tokens.
cls_token = "[CLS]"
sep_token = "[SEP]"
pad_token = "[PAD]"
mask_token = "[MASK]"
for token in [cls_token, pad_token, sep_token]:
if token not in self.get_vocabulary():
for token in [self.cls_token, self.pad_token, self.sep_token]:
if token not in self.vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self.cls_token_id = self.token_to_id(cls_token)
self.sep_token_id = self.token_to_id(sep_token)
self.pad_token_id = self.token_to_id(pad_token)
self.mask_token_id = self.token_to_id(mask_token)
self.cls_token_id = self.token_to_id(self.cls_token)
self.sep_token_id = self.token_to_id(self.sep_token)
self.pad_token_id = self.token_to_id(self.pad_token)
self.mask_token_id = self.token_to_id(self.mask_token)
else:
self.cls_token_id = None
self.sep_token_id = None
6 changes: 6 additions & 0 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -142,6 +142,9 @@ def generate_preprocess(
the sequence (as generation is expected to continue at the end of the
inputted prompt).
"""
if not self.built:
self.build(None)

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
token_ids, padding_mask = self.packer(
@@ -162,6 +165,9 @@ def generate_postprocess(
padding and start/end tokens, and then converting the integer sequence
back to a string.
"""
if not self.built:
self.build(None)

token_ids, padding_mask = x["token_ids"], x["padding_mask"]
token_ids = ops.convert_to_numpy(token_ids)
padding_mask = ops.convert_to_numpy(padding_mask)
2 changes: 1 addition & 1 deletion keras_nlp/models/gpt2/gpt2_causal_lm_test.py
@@ -44,7 +44,7 @@ def setUp(self):
num_heads=2,
hidden_dim=4,
intermediate_dim=8,
max_sequence_length=self.preprocessor.packer.sequence_length,
max_sequence_length=self.preprocessor.sequence_length,
)
self.init_kwargs = {
"preprocessor": self.preprocessor,