Merged
24 commits
a76a91b
Add AlibiBias layer
abuelnasr0 Jan 18, 2024
3f3d968
Add example
abuelnasr0 Jan 18, 2024
f1536df
Convert layer to receive attn_scores and add the alibi bias to it.
abuelnasr0 Jan 22, 2024
9c891ae
Change layer logic
abuelnasr0 Jan 22, 2024
3676476
Add layer test
abuelnasr0 Jan 22, 2024
358e6f0
Format the code
abuelnasr0 Jan 22, 2024
0c949c4
Fix seq_range creation to be int range
abuelnasr0 Jan 22, 2024
5f1ebc1
Change bloom model to use alibi bias
abuelnasr0 Jan 22, 2024
aa3b15f
Format the code
abuelnasr0 Jan 22, 2024
f462ee2
Remove print function
abuelnasr0 Jan 23, 2024
726946d
Change logic to only compute alibi bias once
abuelnasr0 Jan 23, 2024
56a8c59
Change bloom model API calls to match new alibi layer API
abuelnasr0 Jan 23, 2024
7408fd8
Format the code
abuelnasr0 Jan 23, 2024
91c0e04
Add dtype kwarg for the layer weight
abuelnasr0 Jan 23, 2024
410385f
Revert "Add dtype kwarg for the layer weight"
abuelnasr0 Jan 23, 2024
8fc2616
Cast after adding alibi bias
abuelnasr0 Jan 23, 2024
8e750b7
Revert to computing ALiBi bias at each call
abuelnasr0 Jan 25, 2024
a44c6e8
Add compute output shape method for bloom decoder
abuelnasr0 Jan 25, 2024
7f8cc0f
Force shape to be (batch_size, num_heads, query_length, key_length)
abuelnasr0 Jan 25, 2024
9440bd6
Format the code
abuelnasr0 Jan 25, 2024
795eacd
Fix documentation
abuelnasr0 Jan 25, 2024
071189f
Fix the example
abuelnasr0 Jan 25, 2024
f6bf10d
Fix tensorflow2 test fail
abuelnasr0 Jan 25, 2024
4f8fcd0
Fix tensorflow2 test fail
abuelnasr0 Jan 25, 2024
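
For context, the bias this PR's AlibiBias layer adds follows ALiBi (Press et al., 2021) as used by BLOOM: each attention head gets a fixed slope, and the bias grows linearly with key position. Below is a minimal NumPy sketch of that computation; alibi_slopes and add_alibi_bias are illustrative names, not the KerasNLP API, and non-power-of-two head counts are not handled.

import math

import numpy as np

def alibi_slopes(num_heads):
    # Geometric sequence of per-head slopes starting at 2**(-8 / num_heads),
    # assuming num_heads is a power of two.
    start = 2.0 ** (-(2.0 ** -(math.log2(num_heads) - 3)))
    return np.array([start * (start**i) for i in range(num_heads)])

def add_alibi_bias(attention_scores):
    # attention_scores: (batch_size, num_heads, query_length, key_length),
    # the shape the PR settles on in commit 7f8cc0f.
    num_heads, key_length = attention_scores.shape[1], attention_scores.shape[3]
    slopes = alibi_slopes(num_heads)
    # Integer offsets from -(key_length - 1) up to 0, so the bias is zero at
    # the current token and increasingly negative further back.
    seq_range = np.arange(1 - key_length, 1).astype("float32")
    bias = slopes[None, :, None, None] * seq_range[None, None, None, :]
    return attention_scores + bias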
Format the code
abuelnasr0 committed Jan 22, 2024
commit aa3b15ff602457b6aae6be62b7eb39d2c739c518
7 changes: 4 additions & 3 deletions keras_nlp/models/bloom/bloom_attention.py
@@ -95,7 +95,6 @@ def build(self, inputs_shape):
 
         self.built = True
 
-
     def call(
         self,
         hidden_states,
@@ -136,10 +135,12 @@ def call(
         key = ops.transpose(key, [0, 2, 3, 1])
 
         attention_scores = (
-                ops.matmul(query, key) * self.inv_norm_factor
+            ops.matmul(query, key) * self.inv_norm_factor
         )  # [batch_size, num_heads, query_length, kv_length]
         attention_scores = self._alibi_layer(attention_scores)
-        attention_scores = self._softmax(attention_scores, ops.expand_dims(attention_mask, 1))
+        attention_scores = self._softmax(
+            attention_scores, ops.expand_dims(attention_mask, 1)
+        )
         attention_scores = self._dropout_layer(attention_scores)
 
         attention_output = ops.matmul(
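
As the diff shows, the layer is applied to the raw attention scores between the scaled query-key matmul and the masked softmax. A hypothetical usage of the sketch above, mirroring that call contract:

scores = np.zeros((2, 8, 5, 5), dtype="float32")  # (batch, heads, query, key)
biased = add_alibi_bias(scores)
print(biased.shape)  # (2, 8, 5, 5)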