Fixes for rotary embedding
mattdangerw committed Dec 22, 2023
commit 7eb04e09b3e0cc8231ab2df82b6c19701ef900bd
8 changes: 5 additions & 3 deletions keras_nlp/layers/modeling/rotary_embedding.py
@@ -101,16 +101,18 @@ def get_axis(axis):
         sequence_axis = get_axis(self.sequence_axis)

         rotary_dim = ops.shape(inputs)[feature_axis]

         inverse_freq = self._get_inverse_freq(rotary_dim)

         seq_len = ops.shape(inputs)[self.sequence_axis]
         tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index

         tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
-        freq = ops.einsum("i, j -> ij", tensor, inverse_freq)
-        embedding = ops.concatenate((freq, freq), axis=self.feature_axis)
+        freq = ops.einsum("i,j->ij", tensor, inverse_freq)
+        embedding = ops.concatenate((freq, freq), axis=-1)

+        # Reshape the embedding to be broadcastable with input shape.
+        if feature_axis < sequence_axis:
+            embedding = ops.transpose(embedding)
         for axis in range(len(inputs.shape)):
             if axis != sequence_axis and axis != feature_axis:
                 embedding = ops.expand_dims(embedding, axis)
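For context, the hunk above builds a (seq_len, rotary_dim) table of angles from the outer product of positions and inverse frequencies, concatenates it on the last axis, and then reshapes it so it broadcasts against inputs whose sequence and feature axes may sit anywhere. The NumPy sketch below mirrors that logic; it is illustrative only. The helper name cos_sin_embedding and the hard-coded max_wavelength are assumptions, and the layer's scaling_factor is omitted.

import numpy as np

def cos_sin_embedding(inputs, sequence_axis=1, feature_axis=-1,
                      max_wavelength=10000, start_index=0):
    # Resolve negative axes against the input rank.
    ndim = inputs.ndim
    sequence_axis = sequence_axis % ndim
    feature_axis = feature_axis % ndim

    rotary_dim = inputs.shape[feature_axis]
    seq_len = inputs.shape[sequence_axis]

    # Standard RoPE inverse frequencies: 1 / max_wavelength ** (2i / rotary_dim).
    inverse_freq = 1.0 / (
        max_wavelength ** (np.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim)
    )
    positions = np.arange(seq_len, dtype="float32") + start_index

    # Outer product -> (seq_len, rotary_dim // 2), duplicated to rotary_dim.
    freq = np.einsum("i,j->ij", positions, inverse_freq)
    embedding = np.concatenate((freq, freq), axis=-1)

    # Reshape the table so it broadcasts against `inputs`, whatever the axis order.
    if feature_axis < sequence_axis:
        embedding = embedding.T
    for axis in range(ndim):
        if axis != sequence_axis and axis != feature_axis:
            embedding = np.expand_dims(embedding, axis)
    return np.cos(embedding), np.sin(embedding)

# A (batch, seq, feature) input and its (batch, feature, seq) permutation
# yield tables of shape (1, 3, 4) and (1, 4, 3) respectively.
x = np.zeros((2, 3, 4), dtype="float32")
cos, _ = cos_sin_embedding(x)
cos_p, _ = cos_sin_embedding(x.transpose(0, 2, 1), sequence_axis=-1, feature_axis=-2)
print(cos.shape, cos_p.shape)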
12 changes: 12 additions & 0 deletions keras_nlp/layers/modeling/rotary_embedding_test.py
@@ -97,6 +97,18 @@ def test_start_index(self):
             )
         self.assertAllClose(full_output, sequential_output)

+    def test_permuted_axes(self):
+        batch_size, seq_length, feature_size = 2, 3, 4
+        data = random.uniform(shape=(batch_size, seq_length, feature_size))
+        layer = RotaryEmbedding(seq_length)
+        outputs = layer(data)
+        permuted_data = ops.transpose(data, (0, 2, 1))
+        permuted_layer = RotaryEmbedding(
+            seq_length, sequence_axis=-1, feature_axis=-2
+        )
+        permuted_outputs = permuted_layer(permuted_data)
+        self.assertAllClose(outputs, ops.transpose(permuted_outputs, (0, 2, 1)))
+
     def test_float16_dtype(self):
         embedding_layer = RotaryEmbedding(dtype="float16")
         seq_length = 100
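At the call-site level, the new test_permuted_axes corresponds to usage along these lines. This is a rough sketch, assuming the public keras_nlp.layers.RotaryEmbedding export with its default max_wavelength and a Keras 3 backend for keras.ops.convert_to_numpy; the sequence_axis and feature_axis arguments are the ones exercised by the test above.

import numpy as np
from keras import ops
from keras_nlp.layers import RotaryEmbedding

# (batch, sequence, feature): the default axis layout.
x = np.random.uniform(size=(2, 3, 4)).astype("float32")
y = ops.convert_to_numpy(RotaryEmbedding()(x))

# The same data with the sequence and feature axes swapped.
x_t = np.transpose(x, (0, 2, 1))
y_t = ops.convert_to_numpy(
    RotaryEmbedding(sequence_axis=-1, feature_axis=-2)(x_t)
)

# Rotating along permuted axes matches the default layout up to the same permutation.
np.testing.assert_allclose(y, np.transpose(y_t, (0, 2, 1)), atol=1e-5)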