Merged

Changes from 1 commit (35 commits total)

Commits
cd7e2f8
Add VIT Encoder
divyashreepathihalli Mar 28, 2024
53119bd
Add MHAPooling layer + end to end ViT model
fchollet Mar 29, 2024
aa9bd89
Feature/pg gemma changes
VarunS1997 Apr 2, 2024
962aea7
Misc fixes
divyashreepathihalli Apr 2, 2024
f23974b
update vit model and add a test for verifying output shape.
divyashreepathihalli Apr 3, 2024
53865e0
Feature/pg gemma changes
VarunS1997 Apr 18, 2024
86540c0
Vit model weights conversion
divyashreepathihalli Apr 22, 2024
6160ebe
Update imports
divyashreepathihalli Apr 25, 2024
92e697d
Add vit attention
divyashreepathihalli Apr 27, 2024
0454f1f
add paligemma functional model
divyashreepathihalli May 3, 2024
56a95ce
Paligemma full model checkpoints conversion script
divyashreepathihalli May 4, 2024
9a46b6e
Fix ViT build issue
divyashreepathihalli May 4, 2024
2a2e6b1
Multi modal Refactor for PaliGemma
VarunS1997 May 6, 2024
b2a14d8
Export the public API surface
mattdangerw May 7, 2024
53266b1
update image size arg throughout PaliGemma
divyashreepathihalli May 7, 2024
597daaf
Update convert_paligemma_checkpoints.py
divyashreepathihalli May 7, 2024
87d6c5a
Renames for consistency
mattdangerw May 7, 2024
3ab245c
Update convert_pali_gemma_checkpoints.py
divyashreepathihalli May 7, 2024
671a161
More consistency improvements for PaliGemma
mattdangerw May 7, 2024
ffd5d98
Do the same scaling in our backbone we do for generate
mattdangerw May 7, 2024
b94b6d3
Update conversion and add cli arguments
grasskin May 7, 2024
d262569
Add presets
divyashreepathihalli May 7, 2024
caa7cef
Update pali_gemma_causal_lm_preprocesor.py to default text sequence l…
divyashreepathihalli May 8, 2024
c8f327a
Tokenizer fix
divyashreepathihalli May 8, 2024
bc3811b
Allow fit calls for pali_gemma
mattdangerw May 8, 2024
1afea65
Allow generate on unbatch input
mattdangerw May 8, 2024
cfef757
Remove the score function from pali gemma causal lm
mattdangerw May 8, 2024
0ba0920
Greedy sample by default for pali gemma
mattdangerw May 8, 2024
f4d31dd
Added docstrings for paligemma decoder and backbone
VarunS1997 May 10, 2024
e8bff89
Minor fixes for the vit
mattdangerw May 16, 2024
db0d9d1
Add a response_mask input
mattdangerw May 21, 2024
5267d5a
Update pali_gemma_presets.py path
divyashreepathihalli May 21, 2024
83ee31d
Add a tokenizer docstring for pali gemma
mattdangerw May 21, 2024
edc66e8
Update pali_gemma_presets.py
divyashreepathihalli May 21, 2024
569d89e
More consistent defaults for PaliGemma
mattdangerw May 21, 2024
Feature/pg gemma changes
Co-authored-by: divyashreepathihalli <[email protected]>
2 people authored and mattdangerw committed May 21, 2024
commit 53865e0982da1055cdd30b30fa0a3018000ea152
@@ -16,16 +16,18 @@
from keras_nlp.src.layers.modeling.reversible_embedding import (
ReversibleEmbedding,
)
from keras_nlp.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
from keras_nlp.src.models.gemma.gemma_tokenizer import GemmaTokenizer
from keras_nlp.src.models.backbone import Backbone
from keras_nlp.src.models.gemma.rms_normalization import RMSNormalization
from keras_nlp.src.models.paligemma.pali_gemma_decoder_block import (
PaliGemmaDecoderBlock,
)


class PaliGemmaDecoder(keras.layers.Layer):
class PaliGemmaBackbone(Backbone):
def __init__(
self,
img_sequence_length,
vocabulary_size,
sequence_length,
num_layers,
num_query_heads,
num_key_value_heads,
@@ -34,13 +36,11 @@ def __init__(
head_dim,
layer_norm_epsilon=1e-6,
dropout=0,
tokenizer_preset="gemma_2b_en",
text_prefix="answer en",
dtype=None,
**kwargs,
):
self.img_sequence_length = img_sequence_length
self.vocabulary_size = vocabulary_size
self.sequence_length = sequence_length
self.num_layers = num_layers
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
@@ -49,15 +49,10 @@
self.head_dim = head_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout
self.tokenizer_preset = tokenizer_preset
self.text_prefix = text_prefix
self.dtype = dtype

#
# Layers
#
self.tokenizer = GemmaTokenizer.from_preset(self.tokenizer_preset)

self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
@@ -71,11 +66,13 @@ def __init__(
dtype=dtype,
name="token_embedding",
)

self.transformer_layers = []
for i in range(num_layers):
layer = GemmaDecoderBlock(
intermediate_dim=intermediate_dim,
layer = PaliGemmaDecoderBlock(
img_sequence_length=img_sequence_length,
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
num_query_heads=num_query_heads,
head_dim=head_dim,
num_key_value_heads=num_key_value_heads,
@@ -94,42 +91,50 @@
# Functional Model
#
img_embeddings = keras.Input(
shape=(None,), dtype=dtype, name="img_embeddings"
shape=(img_sequence_length, hidden_dim),
dtype=dtype,
name="img_embeddings",
)

# TODO: Is there a good data type for text/string input like this?
text_in = keras.Input(
shape=(sequence_length, vocabulary_size), name="text"
token_ids = keras.Input(
shape=(None,),
dtype="int32",
name="token_ids",
)

prefixed_text = [self.text_prefix + " " + text for text in text_in]
padding_mask = keras.Input(
shape=(None,), dtype="float32", name="padding_mask"
)

tokenized_text = self.tokenizer(prefixed_text)
text_embeddings = self.token_embedding(tokenized_text)
text_embeddings = self.token_embedding(token_ids)

complete_sequence = keras.ops.concatenate(
(img_embeddings, text_embeddings), axis=1
)

transformer_out = complete_sequence
for transformer_layer in self.transformer_layers:
transformer_out = transformer_layer(transformer_out)
transformer_out = transformer_layer(
transformer_out, padding_mask=padding_mask
)

text_out = self.layer_norm(transformer_out)

super().__init__(
inputs={
"img_embeddings": img_embeddings,
"text": text_in,
"token_ids": token_ids,
"padding_mask": padding_mask,
},
outputs={"text_out": text_out},
outputs=text_out,
**kwargs,
)

def get_config(self):
config = super().get_config()
config.update(
{
"img_sequence_length": self.img_sequence_length,
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_query_heads": self.num_query_heads,
@@ -139,9 +144,6 @@ def get_config(self):
"head_dim": self.head_dim,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
"tokenizer_preset": self.tokenizer_preset,
"text_prefix": self.text_prefix,
"dtype": self.dtype,
}
)
return config
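
For reference, here is a minimal sketch of how the refactored PaliGemmaBackbone is constructed and called after this change. The positional argument order and input names mirror the test added below; the concrete sizes are illustrative only and the layer count is kept small for the sketch.

import numpy as np

from keras_nlp.src.models.paligemma.pali_gemma_backbone import PaliGemmaBackbone

# Positional arguments, in the order the test below uses them:
# img_sequence_length, vocabulary_size, num_layers, num_query_heads,
# num_key_value_heads, hidden_dim, intermediate_dim, head_dim.
backbone = PaliGemmaBackbone(
    128,  # img_sequence_length
    256,  # vocabulary_size
    2,    # num_layers (small for the sketch)
    4,    # num_query_heads
    4,    # num_key_value_heads
    128,  # hidden_dim
    256,  # intermediate_dim
    32,   # head_dim
    dtype="float32",
)

batch_size, text_sequence_length = 2, 64
outputs = backbone(
    {
        "img_embeddings": np.random.rand(batch_size, 128, 128).astype("float32"),
        "token_ids": np.random.randint(
            0, 256, (batch_size, text_sequence_length)
        ).astype("int32"),
        "padding_mask": np.ones(
            (batch_size, 128 + text_sequence_length), dtype="int32"
        ),
    }
)
# outputs has shape
# (batch_size, img_sequence_length + text_sequence_length, hidden_dim).
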
151 changes: 151 additions & 0 deletions keras_nlp/src/models/paligemma/pali_gemma_backbone_test.py
@@ -0,0 +1,151 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import keras
import numpy as np
import pytest

from keras_nlp.src.models.gemma.gemma_preprocessor import GemmaPreprocessor
from keras_nlp.src.models.gemma.gemma_preprocessor import GemmaTokenizer
from keras_nlp.src.models.paligemma.pali_gemma_backbone import PaliGemmaBackbone
from keras_nlp.src.tests.test_case import TestCase


@pytest.mark.keras_3_only
class PaliGemmaBackboneTest(TestCase):
def test_paligemma_preprocessing(self):
batch_size = 1
text_sequence_length = 8
proto = os.path.join(self.get_test_data_dir(), "gemma_test_vocab.spm")

tokenizer = GemmaTokenizer(proto)

preprocessor = GemmaPreprocessor(
tokenizer, text_sequence_length, False, False
)

dummy_text = ["answer en the quick brown fox"]

output = preprocessor(dummy_text)
self.assertEqual(
(
batch_size,
text_sequence_length,
),
output["token_ids"].shape,
)

def test_paligemma_backbone(self):
batch_size = 2
img_sequence_length = 128
vocabulary_size = 256
text_sequence_length = 64
hidden_dim = 128
num_layers = 27
num_heads = 16
head_dim = 126
intermediate_size = 77

paligemma = PaliGemmaBackbone(
img_sequence_length,
vocabulary_size,
num_layers,
num_heads,
num_heads,
hidden_dim,
intermediate_size,
head_dim,
dtype="float32",
)

dummy_imgs = np.random.rand(batch_size, img_sequence_length, hidden_dim)
dummy_text = np.random.rand(batch_size, text_sequence_length)

output = paligemma(
inputs={
"token_ids": dummy_text,
"img_embeddings": dummy_imgs,
"padding_mask": np.ones(
(batch_size, text_sequence_length + img_sequence_length),
dtype="int32",
),
}
)
self.assertEqual(
(
batch_size,
text_sequence_length + img_sequence_length,
hidden_dim,
),
output.shape,
)

def test_complete_paligemma_backbone(self):
batch_size = 2
img_sequence_length = 128
vocabulary_size = 256
text_sequence_length = 64
hidden_dim = 128
num_layers = 27
num_heads = 16
head_dim = 126
intermediate_size = 77
proto = os.path.join(self.get_test_data_dir(), "gemma_test_vocab.spm")

tokenizer = GemmaTokenizer(proto)

preprocessor = GemmaPreprocessor(
tokenizer, text_sequence_length, False, False
)

paligemma = PaliGemmaBackbone(
img_sequence_length,
vocabulary_size,
num_layers,
num_heads,
num_heads,
hidden_dim,
intermediate_size,
head_dim,
dtype="float32",
)

dummy_imgs = keras.ops.convert_to_tensor(
np.random.rand(batch_size, img_sequence_length, hidden_dim)
)
dummy_text = [
"answer en the quick brown fox" for i in range(batch_size)
]

output = preprocessor(dummy_text)

output = paligemma(
inputs={
"token_ids": output["token_ids"],
"img_embeddings": dummy_imgs,
"padding_mask": keras.ops.ones(
(batch_size, text_sequence_length + img_sequence_length),
dtype="int32",
),
}
)
self.assertEqual(
(
batch_size,
text_sequence_length + img_sequence_length,
hidden_dim,
),
output.shape,
)
106 changes: 106 additions & 0 deletions keras_nlp/src/models/paligemma/pali_gemma_decoder_block.py
@@ -0,0 +1,106 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.backend import ops
from keras_nlp.src.layers.modeling.transformer_layer_utils import (
compute_causal_mask,
)
from keras_nlp.src.layers.modeling.transformer_layer_utils import (
merge_padding_and_attention_mask,
)
from keras_nlp.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock


class PaliGemmaDecoderBlock(GemmaDecoderBlock):
def __init__(
self,
img_sequence_length,
hidden_dim,
intermediate_dim,
head_dim,
num_query_heads,
num_key_value_heads,
layer_norm_epsilon=1e-6,
dropout=0,
**kwargs,
):
super().__init__(
hidden_dim,
intermediate_dim,
head_dim,
num_query_heads,
num_key_value_heads,
layer_norm_epsilon,
dropout,
**kwargs,
)

self.img_sequence_length = img_sequence_length

def _compute_attention_mask(
self, x, padding_mask, cache, cache_update_index
):
decoder_mask = merge_padding_and_attention_mask(
inputs=x, padding_mask=padding_mask, attention_mask=None
)
batch_size = ops.shape(x)[0]
input_length = output_length = ops.shape(x)[1]
if cache is not None:
input_length = ops.shape(cache)[2]

causal_mask = compute_causal_mask(
batch_size=batch_size,
input_length=input_length,
output_length=output_length,
cache_index=cache_update_index,
)

# Image Sequence Embeddings should be fully self-attended without causality
img_causal_mask = ops.concatenate(
[
ops.ones((batch_size, output_length, self.img_sequence_length)),
ops.zeros(
(
batch_size,
output_length,
input_length - self.img_sequence_length,
)
),
],
axis=-1,
)

causal_mask = ops.maximum(causal_mask, img_causal_mask)

return (
ops.minimum(decoder_mask, causal_mask)
if decoder_mask is not None
else causal_mask
)

def get_config(self):
config = super().get_config()
config.update(
{
"img_sequence_length": self.img_sequence_length,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"head_dim": self.head_dim,
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
}
)
return config
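
A minimal NumPy sketch (not the KerasNLP ops API) of the mask logic in _compute_attention_mask above: every query position may attend to the full image prefix, while attention over the remaining text positions stays causal. The names and sizes here are illustrative only.

import numpy as np

batch_size, img_sequence_length, text_length = 1, 3, 4
total_length = img_sequence_length + text_length

# Lower-triangular causal mask over the whole (image + text) sequence.
causal_mask = np.tril(np.ones((batch_size, total_length, total_length)))

# Ones over the first img_sequence_length key positions, zeros elsewhere,
# matching the ones/zeros concatenate in the decoder block.
img_causal_mask = np.concatenate(
    [
        np.ones((batch_size, total_length, img_sequence_length)),
        np.zeros(
            (batch_size, total_length, total_length - img_sequence_length)
        ),
    ],
    axis=-1,
)

# Union of the two masks: image keys are always visible, text keys are causal.
combined_mask = np.maximum(causal_mask, img_causal_mask)
print(combined_mask[0].astype(int))
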