Skip to content

Commit 9286561

Browse files
authored
Switch deberta to use the "int" dtype (keras-team#1315)
This will be int32 on jax and torch, but int64 on tf, which is what we need for proper accelerator support
1 parent fe9b868 commit 9286561

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

keras_nlp/models/deberta_v3/disentangled_self_attention.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,13 @@ def _get_log_pos(abs_pos, mid):
232232
x1=rel_pos,
233233
x2=log_pos * sign,
234234
)
235-
bucket_pos = ops.cast(bucket_pos, dtype="int64")
235+
bucket_pos = ops.cast(bucket_pos, dtype="int")
236236

237237
return bucket_pos
238238

239239
def _get_rel_pos(self, num_positions):
240-
ids = ops.arange(num_positions, dtype="int64")
240+
ids = ops.arange(num_positions)
241+
ids = ops.cast(ids, dtype="int")
241242
query_ids = ops.expand_dims(ids, axis=-1)
242243
key_ids = ops.expand_dims(ids, axis=0)
243244
key_ids = ops.repeat(key_ids, repeats=num_positions, axis=0)

0 commit comments

Comments
 (0)