Merged
Changes from 1 commit
Commits
35 commits
cd7e2f8
Add VIT Encoder
divyashreepathihalli Mar 28, 2024
53119bd
Add MHAPooling layer + end to end ViT model
fchollet Mar 29, 2024
aa9bd89
Feature/pg gemma changes
VarunS1997 Apr 2, 2024
962aea7
Misc fixes
divyashreepathihalli Apr 2, 2024
f23974b
update vit model and add a test for verifying output shape.
divyashreepathihalli Apr 3, 2024
53865e0
Feature/pg gemma changes
VarunS1997 Apr 18, 2024
86540c0
Vit model weights conversion
divyashreepathihalli Apr 22, 2024
6160ebe
Update imports
divyashreepathihalli Apr 25, 2024
92e697d
Add vit attention
divyashreepathihalli Apr 27, 2024
0454f1f
add paligemma functional model
divyashreepathihalli May 3, 2024
56a95ce
Paligemma full model checkpoints conversion script
divyashreepathihalli May 4, 2024
9a46b6e
Fix ViT build issue
divyashreepathihalli May 4, 2024
2a2e6b1
Multi modal Refactor for PaliGemma
VarunS1997 May 6, 2024
b2a14d8
Export the public API surface
mattdangerw May 7, 2024
53266b1
update image size arg throughout PaliGemma
divyashreepathihalli May 7, 2024
597daaf
Update convert_paligemma_checkpoints.py
divyashreepathihalli May 7, 2024
87d6c5a
Renames for consistency
mattdangerw May 7, 2024
3ab245c
Update convert_pali_gemma_checkpoints.py
divyashreepathihalli May 7, 2024
671a161
More consistency improvements for PaliGemma
mattdangerw May 7, 2024
ffd5d98
Do the same scaling in our backbone we do for generate
mattdangerw May 7, 2024
b94b6d3
Update conversion and add cli arguments
grasskin May 7, 2024
d262569
Add presets
divyashreepathihalli May 7, 2024
caa7cef
Update pali_gemma_causal_lm_preprocesor.py to default text sequence l…
divyashreepathihalli May 8, 2024
c8f327a
Tokenizer fix
divyashreepathihalli May 8, 2024
bc3811b
Allow fit calls for pali_gemma
mattdangerw May 8, 2024
1afea65
Allow generate on unbatch input
mattdangerw May 8, 2024
cfef757
Remove the score function from pali gemma causal lm
mattdangerw May 8, 2024
0ba0920
Greedy sample by default for pali gemma
mattdangerw May 8, 2024
f4d31dd
Added docstrings for paligemma decoder and backbone
VarunS1997 May 10, 2024
e8bff89
Minor fixes for the vit
mattdangerw May 16, 2024
db0d9d1
Add a response_mask input
mattdangerw May 21, 2024
5267d5a
Update pali_gemma_presets.py path
divyashreepathihalli May 21, 2024
83ee31d
Add a tokenizer docstring for pali gemma
mattdangerw May 21, 2024
edc66e8
Update pali_gemma_presets.py
divyashreepathihalli May 21, 2024
569d89e
More consistent defaults for PaliGemma
mattdangerw May 21, 2024
Add vit attention
divyashreepathihalli authored and mattdangerw committed May 21, 2024
commit 92e697def93514da3538eccf5074e485f37a693b
159 changes: 133 additions & 26 deletions keras_nlp/src/models/paligemma/vit.py
@@ -11,12 +11,129 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from keras_nlp.src.backend import config
from keras_nlp.src.backend import keras
from keras_nlp.src.backend import ops
from keras_nlp.src.models.paligemma.vision_embeddings import VisionEmbeddings


class PaliGemmaAttention(keras.layers.Layer):
"""
Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py # noqa: E501
"""

def __init__(self, hidden_dim, num_heads, dropout=0.0, **kwargs):
super().__init__(**kwargs)

self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = self.hidden_dim // self.num_heads
if self.head_dim * self.num_heads != self.hidden_dim:
raise ValueError(
f"hidden_dim must be divisible by num_heads (got `hidden_dim`"
f": {self.hidden_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.dropout_layer = keras.layers.Dropout(self.dropout)
self.scale = self.head_dim**-0.5
self.query_proj = keras.layers.Dense(
units=self.hidden_dim,
name="query_proj",
)
self.key_proj = keras.layers.Dense(
units=self.hidden_dim,
name="key_proj",
)
self.value_proj = keras.layers.Dense(
units=self.hidden_dim,
name="value_proj",
)
self.out_proj = keras.layers.Dense(
units=self.hidden_dim,
name="out_proj",
)

def build(self, input_shape):
self.query_proj.build([None, None, self.hidden_dim])
self.key_proj.build([None, None, self.hidden_dim])
self.value_proj.build([None, None, self.hidden_dim])
self.out_proj.build([None, None, self.hidden_dim])
self.built = True

def _transpose_for_scores(self, tensor, batch_size):
"""
Adapted from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501
"""
# [batch_size, seq_len, all_head_dim] ->
# [batch_size, seq_len, num_heads, head_dim]
tensor = ops.reshape(
tensor, (batch_size, -1, self.num_heads, self.head_dim)
)
# [batch_size, seq_len, num_heads, head_dim] ->
# [batch_size, num_heads, seq_len, head_dim]
return ops.transpose(tensor, axes=[0, 2, 1, 3])

def call(
self,
x,
attention_mask=None,
return_attention_scores=None,
training=False,
):
batch_size = ops.shape(x)[0]
mixed_query_layer = self.query_proj(inputs=x)
mixed_key_layer = self.key_proj(inputs=x)
mixed_value_layer = self.value_proj(inputs=x)
query_layer = self._transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self._transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self._transpose_for_scores(mixed_value_layer, batch_size)

# Scaled dot product between key and query = raw attention scores.
attention_scores = ops.matmul(
query_layer, ops.transpose(key_layer, axes=[0, 1, 3, 2])
)
dk = ops.cast(ops.sqrt(self.head_dim), dtype=attention_scores.dtype)
attention_scores = ops.divide(
attention_scores, dk
) # (batch_size, num_heads, seq_len_q, seq_len_k)

if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in the
# call() function)
attention_scores = ops.add(attention_scores, attention_mask)

# Normalize the attention scores to probabilities.
attention_probs = ops.softmax(attention_scores, axis=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
dropout_attention_probs = self.dropout_layer(
inputs=attention_probs, training=training
)

attn_output = ops.matmul(dropout_attention_probs, value_layer)
attn_output = ops.transpose(attn_output, axes=[0, 2, 1, 3])

# (batch_size, seq_len_q, hidden_dim)
attn_output = ops.reshape(
attn_output, (batch_size, -1, self.hidden_dim)
)

attn_output = self.out_proj(attn_output, training=training)
return (attn_output, attention_probs)

def get_config(self):
config = super().get_config()
config.update(
{
"hidden_dim": self.hidden_dim,
"num_heads": self.num_heads,
"dropout": self.dropout,
}
)
return config

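Not part of the diff: a minimal usage sketch of the new layer, assuming the import path from this file (keras_nlp/src/models/paligemma/vit.py), a Keras 3 backend, and made-up sizes.

import numpy as np
from keras_nlp.src.models.paligemma.vit import PaliGemmaAttention

# Hypothetical sizes: 2 images, 16 patches each, hidden_dim of 8 split across 2 heads.
hidden_dim, num_heads = 8, 2
x = np.random.uniform(size=(2, 16, hidden_dim)).astype("float32")

layer = PaliGemmaAttention(hidden_dim=hidden_dim, num_heads=num_heads)
outputs, probs = layer(x)
print(outputs.shape)  # (2, 16, 8): same shape as the input patch sequence
print(probs.shape)    # (2, 2, 16, 16): per-head attention weights over patches

The tuple return (attention output plus attention probabilities) mirrors the Hugging Face CLIP attention that the docstring references.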

class VitEncoderBlock(keras.layers.Layer):
def __init__(
self,
@@ -35,39 +152,31 @@ def compute_attention(self, x, mask=None):
mask = ops.cast(mask, dtype=x.dtype) if mask is not None else None

return self.attn(
x,
- x,
- x,
attention_mask=mask,
- return_attention_scores=True,
)[0]

def build(self, input_shape):
self.hidden_dim = input_shape[-1]
- self.attn = keras.layers.MultiHeadAttention(
+ self.attn = PaliGemmaAttention(
+ self.hidden_dim,
self.num_heads,
- key_dim=self.hidden_dim // self.num_heads,
name="multi_head_attention",
)
self.layer_norm_1 = keras.layers.LayerNormalization(
epsilon=1e-5, name="layer_norm_1"
epsilon=1e-6, name="layer_norm_1"
)
self.mlp_dense_1 = keras.layers.Dense(
- self.intermediate_size,
- name="mlp_dense_1",
- activation="gelu",
+ self.intermediate_size, name="mlp_dense_1"
)
self.mlp_dense_2 = keras.layers.Dense(
self.hidden_dim,
name="mlp_dense_2",
)
self.layer_norm_2 = keras.layers.LayerNormalization(
epsilon=1e-5, name="layer_norm_2"
)
self.attn.build(
[None, None, self.hidden_dim],
[None, None, self.hidden_dim],
epsilon=1e-6, name="layer_norm_2"
)
+ self.attn.build(None)
self.layer_norm_1.build([None, None, self.hidden_dim])
self.mlp_dense_1.build([None, None, self.hidden_dim])
self.mlp_dense_2.build([None, None, self.intermediate_size])
@@ -77,10 +186,12 @@ def build(self, input_shape):
def call(self, x, mask=None):
residual = x
x = self.layer_norm_1(x)
+ # mask = ops.ones_like(x) if mask is None else mask
x = self.compute_attention(x, mask)
x = x + residual
residual = x
x = self.mlp_dense_1(self.layer_norm_2(residual))
+ x = keras.activations.gelu(x, approximate=True)
x = self.mlp_dense_2(x)
return residual + x

@@ -103,20 +214,13 @@ class VitEncoder(keras.layers.Layer):
def __init__(
self, hidden_dim, num_layers, num_heads, intermediate_size, **kwargs
):
- if not config.keras_3():
- raise ValueError(
- "`PaLIGemma` requires Keras 3. Run `pip install -U keras` "
- "upgrade your Keras version, or see "
- "https://keras.io/getting_started/ "
- "for more info on Keras versions and installation."
- )
super().__init__(**kwargs)
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.num_heads = num_heads
self.intermediate_size = intermediate_size
self.encoder_layer_norm = keras.layers.LayerNormalization(
epsilon=1e-5, name="encoder_layer_norm"
epsilon=1e-6, name="encoder_layer_norm"
)
self.vision_embeddings = VisionEmbeddings(hidden_dim=hidden_dim)
self.resblocks = [
@@ -142,7 +246,8 @@ def call(
x = self.vision_embeddings(x)
for block in self.resblocks:
x = block(x, mask=mask)
- return self.encoder_layer_norm(x)
+ x = self.encoder_layer_norm(x)
+ return x

def compute_output_shape(self, inputs_shape):
return [inputs_shape[0], inputs_shape[1], self.hidden_dim]
@@ -177,7 +282,7 @@ def build(self, input_shape):
key_dim=input_shape[-1] // self.num_heads,
num_heads=self.num_heads,
)
- self.layer_norm = keras.layers.LayerNormalization()
+ self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-6)
self.mlp_block = keras.Sequential(
[
keras.layers.Dense(self.hidden_dim, activation="gelu"),
@@ -215,6 +320,8 @@ def __init__(
if include_rescaling:
x = keras.layers.Rescaling(scale=1 / 255.0)(inputs)

+ self.pooled = None

encoded = VitEncoder(
hidden_dim,
num_layers,
@@ -238,10 +345,10 @@
"Expected one of 'map', 'gap', None. "
f"Received: pooling={pooling}"
)

outputs = keras.layers.Dense(
num_classes, activation=classifier_activation, name="classifier"
)(pooled)
+ self.pooled = pooled
super().__init__(inputs=inputs, outputs=outputs, name=name, **kwargs)

self.num_heads = num_heads
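As an aside on the encoder block changes above: VitEncoderBlock keeps a standard pre-norm transformer layout, but GELU is now applied as a separate tanh-approximate activation after mlp_dense_1 rather than as the Dense layer's fused activation, and the layer norm epsilon moves from 1e-5 to 1e-6. A rough standalone sketch of the same pattern, with keras.layers.MultiHeadAttention standing in for PaliGemmaAttention and made-up sizes, not taken from the PR:

import keras
import numpy as np

# Illustrative sizes only.
hidden_dim, intermediate_size, num_heads = 8, 32, 2

layer_norm_1 = keras.layers.LayerNormalization(epsilon=1e-6)
attn = keras.layers.MultiHeadAttention(num_heads, key_dim=hidden_dim // num_heads)
layer_norm_2 = keras.layers.LayerNormalization(epsilon=1e-6)
mlp_dense_1 = keras.layers.Dense(intermediate_size)
mlp_dense_2 = keras.layers.Dense(hidden_dim)

x = keras.ops.convert_to_tensor(
    np.random.uniform(size=(2, 16, hidden_dim)).astype("float32")
)

residual = x
y = layer_norm_1(x)
y = attn(y, y)  # self-attention over the normalized patch sequence
x = y + residual  # first residual connection
residual = x
y = mlp_dense_1(layer_norm_2(x))
y = keras.activations.gelu(y, approximate=True)  # separate tanh-approximate GELU
y = mlp_dense_2(y)
y = y + residual  # second residual connection
print(keras.ops.shape(y))  # (2, 16, 8)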