13 changes: 7 additions & 6 deletions keras_nlp/src/layers/modeling/rotary_embedding.py
@@ -35,7 +35,8 @@ class RotaryEmbedding(keras.layers.Layer):
     Args:
         max_wavelength: int. The maximum angular wavelength of the sine/cosine
             curves.
-        scaling_factor: float. The scaling factor used to scale frequency range.
+        scaling_factor: float. The scaling factor used to scale positions of
+            the tokens.
         sequence_axis: int. Sequence axis in the input tensor.
         feature_axis: int. Feature axis in the input tensor.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
@@ -125,6 +126,7 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
         else:
             positions = ops.cast(positions, "float32")
 
+        positions = positions / ops.cast(self.scaling_factor, "float32")
         freq = ops.einsum("i,j->ij", positions, inverse_freq)
         embedding = ops.stack((freq, freq), axis=-2)
         embedding = ops.reshape(
@@ -143,12 +145,11 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
         return cos_emb, sin_emb
 
     def _get_inverse_freq(self, rotary_dim):
-        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
-        freq_range = freq_range / ops.cast(self.scaling_factor, "float32")
-        inverse_freq = 1.0 / (
-            self.max_wavelength
-            ** (freq_range / ops.cast(rotary_dim, "float32"))
+        freq_range = ops.divide(
+            ops.arange(0, rotary_dim, 2, dtype="float32"),
+            ops.cast(rotary_dim, "float32"),
         )
+        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
         return inverse_freq
 
     def get_config(self):
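The net effect of the rotary_embedding.py change: `scaling_factor` no longer rescales the wavelength exponent inside `_get_inverse_freq`; instead it linearly stretches token positions, matching the Hugging Face `LlamaLinearScalingRotaryEmbedding` referenced in the tests below. A minimal NumPy sketch of the new computation, with the layer's default `max_wavelength` of 10000 (variable names here are illustrative, not part of the layer's API):

    import numpy as np

    max_wavelength = 10000.0
    scaling_factor = 2.0
    rotary_dim = 4

    # Inverse frequencies now depend only on max_wavelength and rotary_dim.
    freq_range = np.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim
    inverse_freq = 1.0 / (max_wavelength**freq_range)

    # scaling_factor stretches the positions themselves: token p gets the
    # rotation angles an unscaled model would apply at position p / 2.
    positions = np.arange(3, dtype="float32") / scaling_factor
    freq = np.einsum("i,j->ij", positions, inverse_freq)
    cos_emb, sin_emb = np.cos(freq), np.sin(freq)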
73 changes: 73 additions & 0 deletions keras_nlp/src/layers/modeling/rotary_embedding_test.py
@@ -168,3 +168,76 @@ def test_positions_array(self):
         got = layer(x, positions=positions)
 
         np.testing.assert_allclose(expected, ops.convert_to_numpy(got))
+
+    def test_rope_scaling(self):
+        # Reference values computed from the Hugging Face Llama
+        # implementation with `scaling_factor` = 2.0:
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=2.0
+        # )
+        # query = torch.ones((1, 2, 3, 4))  # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(3, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # print(query.transpose(1, 2))
+        expected = [
+            [
+                [
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                ],
+                [
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                ],
+                [
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                ],
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=2.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 3, 2, 4))),
+            ops.convert_to_tensor(expected),
+        )
+
+    def test_rope_scaling_with_kv_cache(self):
+        # Reference values computed from the Hugging Face Llama
+        # implementation with `scaling_factor` = 5.0:
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=5.0
+        # )
+        # query = torch.ones((1, 2, 1, 4))  # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(12, 13, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # print(query.transpose(1, 2))
+        expected = [
+            [
+                [
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                ]
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=5.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 1, 2, 4)), start_index=12),
+            ops.convert_to_tensor(expected),
+        )
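Taken together, the two tests pin down both the prefill path and the `start_index` path. A usage sketch, not part of this PR, assuming the Keras 3 `keras.ops` namespace and the public `keras_nlp.layers.RotaryEmbedding` export; shapes follow the tests above, `[batch, seq_len, num_heads, head_dim]`:

    from keras import ops
    from keras_nlp.layers import RotaryEmbedding

    layer = RotaryEmbedding(scaling_factor=2.0)

    # Prefill: rotate a three-token prompt at positions 0..2, which the
    # layer scales to effective positions 0.0, 0.5, 1.0.
    rotated_prompt = layer(ops.ones((1, 3, 2, 4)))

    # Decode step with a KV cache: rotate one new token whose absolute
    # position is 3 (effective position 1.5).
    rotated_token = layer(ops.ones((1, 1, 2, 4)), start_index=3)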