Merged
Changes from 1 commit
Commits (61)
e2b569a
Add phi3
abuelnasr0 Apr 25, 2024
c963f7b
Add phi3 to init
abuelnasr0 Apr 25, 2024
0f30b5f
layer naming and some nits
abuelnasr0 Apr 25, 2024
1d18572
Decoder layers naming
abuelnasr0 Apr 25, 2024
3d24cb2
Remove bias from einsumdense
abuelnasr0 Apr 26, 2024
a98369a
nit fix for layernorm
abuelnasr0 Apr 26, 2024
aecd9e2
Add SuRotary embedding
abuelnasr0 Apr 26, 2024
368e5a2
Remove print()
abuelnasr0 Apr 26, 2024
4267257
Add conversion script
abuelnasr0 Apr 27, 2024
e946503
Nit fix in script
abuelnasr0 Apr 27, 2024
145864b
Add phi3_4k as default preset
abuelnasr0 Apr 27, 2024
c5c78ed
Fix Doc and nit changes
abuelnasr0 Apr 27, 2024
7b0def0
Nit in test
abuelnasr0 Apr 27, 2024
c78c482
Doc fix
abuelnasr0 Apr 29, 2024
cd0381a
Add length check for rope scaling factors
abuelnasr0 Apr 29, 2024
6f9108d
Calculate the mean of the absolute difference in conversion script
abuelnasr0 Apr 29, 2024
8e99e04
Fix typo
abuelnasr0 Apr 29, 2024
b3ca8a3
Add tokenizer and preprocessor
abuelnasr0 May 2, 2024
b53a326
Format fix
abuelnasr0 May 2, 2024
0e37c9a
Fix dtype and device in conversion script
abuelnasr0 May 2, 2024
45ab340
Batch the input
abuelnasr0 May 2, 2024
a459038
Batch the input
abuelnasr0 May 2, 2024
9c38dec
Nit
abuelnasr0 May 2, 2024
07832e5
Add notify for upload
abuelnasr0 May 2, 2024
fc1cf0b
Add causal_lm preprocessor
abuelnasr0 May 2, 2024
aac962d
Add causal lm
abuelnasr0 May 2, 2024
49103f3
Fix format
abuelnasr0 May 2, 2024
49a5495
small fixes
abuelnasr0 May 2, 2024
dcead36
Add phi3 to the new api
abuelnasr0 May 2, 2024
0d990f5
Api gen
abuelnasr0 May 2, 2024
ac8770d
Public named sublayers
abuelnasr0 May 6, 2024
1c2a70e
Public named sublayers in decoder layer
abuelnasr0 May 6, 2024
e63430a
Simplify dropout
abuelnasr0 May 6, 2024
9ac44b6
Fix tokenizer tests
abuelnasr0 May 6, 2024
6355e67
Fix conversion script
abuelnasr0 May 6, 2024
a206480
use preprocessor
abuelnasr0 May 6, 2024
968a220
use preprocessor
abuelnasr0 May 6, 2024
7c17bd1
Fix keras input
abuelnasr0 May 6, 2024
0c165a0
Fix keras model input
abuelnasr0 May 6, 2024
074ea69
Only validate with validate_dtype
abuelnasr0 May 6, 2024
d5c7fab
Only validate with validate_dtype
abuelnasr0 May 6, 2024
0368483
Change seq length
abuelnasr0 May 6, 2024
1eed34b
Change text
abuelnasr0 May 6, 2024
d048f2d
Set pad token id to 0
abuelnasr0 May 7, 2024
3af4096
Default stop at EOS and EOT
abuelnasr0 May 7, 2024
c9f0ad9
Add presets
abuelnasr0 May 7, 2024
5c5e4ef
Add presets and tests to tokenizer
abuelnasr0 May 7, 2024
5225c6f
Add preprocessor preset tests
abuelnasr0 May 7, 2024
b02c0b4
Add preset tests to causal_lm
abuelnasr0 May 7, 2024
4ab0d32
Add backbone preset tests
abuelnasr0 May 7, 2024
b2b7c55
Naming nits
abuelnasr0 May 7, 2024
0dff9f1
Clean SuRotaryEmbedding
abuelnasr0 May 7, 2024
9552750
Lower case file name
abuelnasr0 May 9, 2024
b76f314
Save SuScaled rope factors as python lists
abuelnasr0 May 9, 2024
ad585b0
Rename original_max_seq_length to training_seq_length
abuelnasr0 May 9, 2024
9f10b63
Format
abuelnasr0 May 9, 2024
55e15bf
Remove placeholders tokens from spm
abuelnasr0 May 9, 2024
19fc9ca
Edit examples
abuelnasr0 May 9, 2024
f0a4236
Nit in generate
abuelnasr0 May 10, 2024
c205c20
Change training_seq_length to pretraining_seq_length
abuelnasr0 May 14, 2024
b735170
Update links
mattdangerw May 17, 2024
Change training_seq_length to pretraining_seq_length
abuelnasr0 committed May 14, 2024
commit c205c2070fbf3653981497985b58c1d37535c5b3
8 changes: 4 additions & 4 deletions keras_nlp/src/models/phi3/phi3_attention.py
@@ -30,7 +30,7 @@ def __init__(
         kernel_initializer="glorot_uniform",
         dropout=0,
         max_sequence_length=4096,
-        training_sequence_length=4096,
+        pretraining_sequence_length=4096,
         rope_max_wavelength=10000,
         rope_scaling_type=None,
         rope_scaling_short_factor=None,
@@ -44,7 +44,7 @@ def __init__(
         self.dropout = dropout
 
         self.max_sequence_length = max_sequence_length
-        self.training_sequence_length = training_sequence_length
+        self.pretraining_sequence_length = pretraining_sequence_length
         self.rope_max_wavelength = rope_max_wavelength
         self.rope_scaling_type = rope_scaling_type
         self.rope_scaling_short_factor = rope_scaling_short_factor
@@ -148,7 +148,7 @@ def build(self, inputs_shape):
                 inverese_freq_short_factor=self.rope_scaling_short_factor,
                 inverese_freq_long_factor=self.rope_scaling_long_factor,
                 max_sequence_length=self.max_sequence_length,
-                training_sequence_length=self.training_sequence_length,
+                pretraining_sequence_length=self.pretraining_sequence_length,
                 max_wavelength=self.rope_max_wavelength,
                 dtype=self.dtype_policy,
             )
@@ -249,7 +249,7 @@ def get_config(self):
                ),
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
-               "training_sequence_length": self.training_sequence_length,
+               "pretraining_sequence_length": self.pretraining_sequence_length,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_type": self.rope_scaling_type,
                "rope_scaling_short_factor": self.rope_scaling_short_factor,
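Because the rename also flows through `get_config()` (last hunk above), a config dict serialized before this commit still carries the old key. Below is a minimal migration sketch, assuming plain-dict configs; the helper name is hypothetical, not part of keras_nlp.

# Hypothetical helper (not part of keras_nlp): remap the old config key
# to the new one before re-creating the layer from its config dict.
def migrate_phi3_config(config):
    if "training_sequence_length" in config:
        config = dict(config)  # copy so the caller's dict is untouched
        config["pretraining_sequence_length"] = config.pop(
            "training_sequence_length"
        )
    return config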
12 changes: 6 additions & 6 deletions keras_nlp/src/models/phi3/phi3_backbone.py
@@ -55,8 +55,8 @@ class Phi3Backbone(Backbone):
            decoder.
        max_sequence_length (int, optional): The maximum sequence length
            that this model might ever be used with. Defaults to `4096`.
-       training_sequence_length (int, optional): The maximum sequence length
-           that the model was trained with. Defaults to `4096`.
+       pretraining_sequence_length (int, optional): The maximum sequence length
+           that the model was pretrained with. Defaults to `4096`.
        rope_max_wavelength (int, optional): The maximum angular wavelength of
            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
        rope_scaling_type (str, optional): The type of the rope scaling. Can be
@@ -118,7 +118,7 @@ def __init__(
        layer_norm_epsilon=1e-6,
        dropout=0.0,
        max_sequence_length=4096,
-       training_sequence_length=4096,
+       pretraining_sequence_length=4096,
        rope_max_wavelength=10000,
        rope_scaling_type=None,
        rope_scaling_short_factor=None,
@@ -148,7 +148,7 @@ def __init__(
                kernel_initializer=_phi3_kernel_initializer(stddev=0.02),
                dropout=dropout,
                max_sequence_length=max_sequence_length,
-               training_sequence_length=training_sequence_length,
+               pretraining_sequence_length=pretraining_sequence_length,
                rope_scaling_type=rope_scaling_type,
                rope_scaling_short_factor=rope_scaling_short_factor,
                rope_scaling_long_factor=rope_scaling_long_factor,
@@ -194,7 +194,7 @@ def __init__(
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length
-       self.training_sequence_length = training_sequence_length
+       self.pretraining_sequence_length = pretraining_sequence_length
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_type = rope_scaling_type
        self.rope_scaling_short_factor = rope_scaling_short_factor
@@ -213,7 +213,7 @@ def get_config(self):
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
-               "training_sequence_length": self.training_sequence_length,
+               "pretraining_sequence_length": self.pretraining_sequence_length,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_type": self.rope_scaling_type,
                "rope_scaling_short_factor": self.rope_scaling_short_factor,
2 changes: 1 addition & 1 deletion keras_nlp/src/models/phi3/phi3_backbone_test.py
@@ -36,7 +36,7 @@ def setUp(self):
            "hidden_dim": 8,
            "intermediate_dim": 12,
            "max_sequence_length": 10,
-           "training_sequence_length": 5,
+           "pretraining_sequence_length": 5,
            "rope_scaling_type": "su",
            "rope_scaling_short_factor": [1.2, 1.4],
            "rope_scaling_long_factor": [0.8, 0.6],
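For orientation, a hedged instantiation sketch built from the test config above. The first four argument names and values are assumptions not visible in this diff (chosen so the head dimension of 4 is consistent with the two rope scaling factors); the remaining arguments mirror the test config verbatim.

from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone

# Sketch only: the first four arguments are assumed names/values not shown
# in this diff; the rest are taken from the test config above.
backbone = Phi3Backbone(
    vocabulary_size=10,            # assumed
    num_layers=2,                  # assumed
    num_query_heads=2,             # assumed; hidden_dim / 2 heads = head_dim 4
    num_key_value_heads=1,         # assumed
    hidden_dim=8,
    intermediate_dim=12,
    max_sequence_length=10,
    pretraining_sequence_length=5,
    rope_scaling_type="su",
    rope_scaling_short_factor=[1.2, 1.4],
    rope_scaling_long_factor=[0.8, 0.6],
)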
8 changes: 4 additions & 4 deletions keras_nlp/src/models/phi3/phi3_decoder.py
@@ -38,7 +38,7 @@ def __init__(
        kernel_initializer="glorot_uniform",
        dropout=0,
        max_sequence_length=4096,
-       training_sequence_length=4096,
+       pretraining_sequence_length=4096,
        rope_max_wavelength=10000,
        rope_scaling_type=None,
        rope_scaling_short_factor=None,
@@ -52,7 +52,7 @@ def __init__(
        self.num_key_value_heads = num_key_value_heads
 
        self.max_sequence_length = max_sequence_length
-       self.training_sequence_length = training_sequence_length
+       self.pretraining_sequence_length = pretraining_sequence_length
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_type = rope_scaling_type
        self.rope_scaling_short_factor = rope_scaling_short_factor
@@ -81,7 +81,7 @@ def build(self, decoder_sequence_shape):
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dropout=self.dropout,
            max_sequence_length=self.max_sequence_length,
-           training_sequence_length=self.training_sequence_length,
+           pretraining_sequence_length=self.pretraining_sequence_length,
            rope_max_wavelength=self.rope_max_wavelength,
            rope_scaling_type=self.rope_scaling_type,
            rope_scaling_short_factor=self.rope_scaling_short_factor,
@@ -249,7 +249,7 @@ def get_config(self):
                ),
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
-               "training_sequence_length": self.training_sequence_length,
+               "pretraining_sequence_length": self.pretraining_sequence_length,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_type": self.rope_scaling_type,
                "rope_scaling_short_factor": self.rope_scaling_short_factor,
16 changes: 8 additions & 8 deletions keras_nlp/src/models/phi3/phi3_rotary_embedding.py
@@ -31,8 +31,8 @@ class Phi3SuScaledRotaryEmbedding(RotaryEmbedding):
            `sequence_length` is larger than `original_max_sequence_length`.
        max_sequence_length: int. The maximum sequence length that this
            model might ever be used with.
-       training_sequence_length: int. The maximum sequence length that
-           this model was trained with.
+       pretraining_sequence_length: int. The maximum sequence length that
+           this model was pretrained with.
        max_wavelength: int. The maximum angular wavelength of the sine/cosine
            curves.
 
@@ -53,24 +53,24 @@ def __init__(
        inverese_freq_short_factor,
        inverese_freq_long_factor,
        max_sequence_length=4096,
-       training_sequence_length=4096,
+       pretraining_sequence_length=4096,
        max_wavelength=10000,
        **kwargs
    ):
        super().__init__(max_wavelength=max_wavelength, **kwargs)
        self.max_sequence_length = max_sequence_length
-       self.training_sequence_length = training_sequence_length
+       self.pretraining_sequence_length = pretraining_sequence_length
 
        scaling_factor = (
-           self.max_sequence_length / self.training_sequence_length
+           self.max_sequence_length / self.pretraining_sequence_length
        )
        if scaling_factor <= 1.0:
            self.embedding_scaling_factor = 1.0
        else:
            self.embedding_scaling_factor = math.sqrt(
                1
                + math.log(scaling_factor)
-               / math.log(self.training_sequence_length)
+               / math.log(self.pretraining_sequence_length)
            )
 
        self.inverese_freq_short_factor = inverese_freq_short_factor
@@ -84,7 +84,7 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
        inverse_freq = self._get_inverse_freq(rotary_dim)
 
        # Multiply inverse_freq by a factor.
-       if ops.shape(inputs)[sequence_axis] > self.training_sequence_length:
+       if ops.shape(inputs)[sequence_axis] > self.pretraining_sequence_length:
            inverse_freq = ops.divide(
                inverse_freq,
                ops.convert_to_tensor(self.inverese_freq_long_factor),
@@ -128,7 +128,7 @@ def get_config(self):
        config.update(
            {
                "max_sequence_length": self.max_sequence_length,
-               "training_sequence_length": self.training_sequence_length,
+               "pretraining_sequence_length": self.pretraining_sequence_length,
                "inverese_freq_short_factor": self.inverese_freq_short_factor,
                "inverese_freq_long_factor": self.inverese_freq_long_factor,
            }
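To make the arithmetic in `__init__` above concrete, here is a standalone sketch of the su-scaling attention factor with illustrative lengths (131072 / 4096 mirrors a long-context Phi-3 setup; the variable names are local to the sketch, not the layer's attributes).

import math

# Standalone sketch of the embedding scaling factor computed above.
max_sequence_length = 131072           # illustrative long-context length
pretraining_sequence_length = 4096     # illustrative pretraining length

scaling_factor = max_sequence_length / pretraining_sequence_length  # 32.0
if scaling_factor <= 1.0:
    # No extension beyond pretraining length: leave embeddings unscaled.
    embedding_scaling_factor = 1.0
else:
    embedding_scaling_factor = math.sqrt(
        1 + math.log(scaling_factor) / math.log(pretraining_sequence_length)
    )
print(round(embedding_scaling_factor, 4))  # 1.1902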