13 changes: 13 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
103 changes: 103 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py
@@ -0,0 +1,103 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras import layers
from keras import ops


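# "Quick GELU": the sigmoid-based approximation of GELU,
# x * sigmoid(1.702 * x), matching the activation used by the
# original CLIP weights.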
def quick_gelu(x):
return x * ops.sigmoid(1.702 * x)


class CLIPEncoderBlock(layers.Layer):
def __init__(
self,
hidden_dim,
num_heads,
intermediate_dim,
intermediate_activation="quick_gelu",
**kwargs,
):
super().__init__(**kwargs)
if hidden_dim % num_heads != 0:
raise ValueError(
"`hidden_dim` must be divisible by `num_heads`. "
f"Received: hidden_dim={hidden_dim}, num_heads={num_heads}"
)
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.intermediate_dim = intermediate_dim
self.intermediate_activation = intermediate_activation

if intermediate_activation == "quick_gelu":
intermediate_activation = quick_gelu

self.layer_norm_1 = layers.LayerNormalization(
epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1"
)
self.attention = layers.MultiHeadAttention(
num_heads,
hidden_dim // num_heads,
dtype=self.dtype_policy,
name="attention",
)
self.layer_norm_2 = layers.LayerNormalization(
epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_2"
)
self.dense_1 = layers.Dense(
self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
)
self.activation = layers.Activation(
intermediate_activation, dtype=self.dtype_policy, name="activation"
)
self.dense_2 = layers.Dense(
self.hidden_dim, dtype=self.dtype_policy, name="dense_2"
)

def build(self, input_shape):
self.layer_norm_1.build(input_shape)
self.attention.build(input_shape, input_shape, input_shape)
self.layer_norm_2.build(input_shape)
self.dense_1.build(input_shape)
input_shape = self.dense_1.compute_output_shape(input_shape)
self.dense_2.build(input_shape)

def compute_output_shape(self, inputs_shape):
outputs_shape = list(inputs_shape)
outputs_shape[-1] = self.hidden_dim
return outputs_shape

def call(self, x, training=None):
residual = x
x = self.layer_norm_1(x)
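        # CLIP's text encoder is causal: each position attends only to
        # itself and earlier tokens, hence `use_causal_mask=True`.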
x = self.attention(x, x, x, training=training, use_causal_mask=True)
x = ops.add(residual, x)

residual = x
x = self.dense_1(self.layer_norm_2(residual))
x = self.activation(x)
x = self.dense_2(x)
x = ops.add(residual, x)
return x

def get_config(self):
config = super().get_config()
config.update(
{
"hidden_dim": self.hidden_dim,
"num_heads": self.num_heads,
"intermediate_dim": self.intermediate_dim,
"intermediate_activation": self.intermediate_activation,
}
)
return config
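A minimal usage sketch (not part of the diff) showing the block's interface; the shapes and argument values below are illustrative assumptions, not requirements of the layer:

import numpy as np
from keras_nlp.src.models.stable_diffusion_v3.clip_encoder_block import (
    CLIPEncoderBlock,
)

# Hypothetical config: 8 heads over a 512-dim hidden state, 4x-wide MLP.
block = CLIPEncoderBlock(hidden_dim=512, num_heads=8, intermediate_dim=2048)
x = np.zeros((2, 77, 512), dtype="float32")  # (batch, sequence, hidden)
y = block(x)  # pre-norm attention + MLP; output shape matches the input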
104 changes: 104 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py
@@ -0,0 +1,104 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import (
CLIPTokenizer,
)
from keras_nlp.src.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)

try:
import tensorflow as tf
except ImportError:
tf = None


class CLIPPreprocessor(Preprocessor):
tokenizer_cls = CLIPTokenizer

def __init__(
self,
tokenizer,
sequence_length=77,
add_start_token=True,
add_end_token=False,
to_lower=True,
pad_with_end_token=True,
**kwargs,
):
super().__init__(**kwargs)
self.tokenizer = tokenizer
self.sequence_length = sequence_length
self.add_start_token = add_start_token
self.add_end_token = add_end_token
self.to_lower = to_lower
self.pad_with_end_token = pad_with_end_token

def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
# assets have loaded when restoring a saved model.
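        # Note: many CLIP text checkpoints pad sequences with the
        # end-of-text token rather than a dedicated pad token, which is
        # why `pad_with_end_token` defaults to `True`.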
pad_value = self.tokenizer.pad_token_id
if self.pad_with_end_token:
pad_value = self.tokenizer.end_token_id

self.packer = StartEndPacker(
start_value=self.tokenizer.start_token_id,
end_value=self.tokenizer.end_token_id,
pad_value=pad_value,
sequence_length=self.sequence_length,
return_padding_mask=True,
)
self.built = True

# TODO: Use `@tf_preprocessing_function` after rebasing.
def call(self, x, y=None, sample_weight=None, sequence_length=None):
x = convert_inputs_to_list_of_tensor_segments(x)
if len(x) != 1:
raise ValueError(
[Inline review thread]
Member: side note, I've cleaned this up on a commit now on master; this will change slightly when I rebase the whole branch.
Collaborator (Author): I've added a TODO for this call.
"T5XXL requires each input feature to contain only "
f"one segment, but received {len(x)}. If you are using T5XXL"
" for a multi-segment classification task, please refer to "
"classification models like BERT or RoBERTa."
)
if self.to_lower:
x = tf.strings.lower(x)
sequence_length = sequence_length or self.sequence_length
token_ids, padding_mask = self.packer(
self.tokenizer(x[0]),
sequence_length=sequence_length,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
x = {
"token_ids": token_ids,
"padding_mask": padding_mask,
}
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

def get_config(self):
config = super().get_config()
config.update(
{
"sequence_length": self.sequence_length,
"add_start_token": self.add_start_token,
"add_end_token": self.add_end_token,
"to_lower": self.to_lower,
"pad_with_end_token": self.pad_with_end_token,
}
)
return config
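A quick sketch of the call contract (assuming a `tokenizer` constructed elsewhere, e.g. as in the tests below): the preprocessor lowercases, tokenizes, packs start/end tokens, and returns a features dict:

preprocessor = CLIPPreprocessor(tokenizer=tokenizer, sequence_length=77)
features = preprocessor(["a photo of an astronaut riding a horse"])
# features["token_ids"].shape    -> (1, 77)
# features["padding_mask"].shape -> (1, 77)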
78 changes: 78 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor_test.py
@@ -0,0 +1,78 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from keras_nlp.src.models.stable_diffusion_v3.clip_preprocessor import (
CLIPPreprocessor,
)
from keras_nlp.src.models.stable_diffusion_v3.clip_tokenizer import (
CLIPTokenizer,
)
from keras_nlp.src.tests.test_case import TestCase


class CLIPPreprocessorTest(TestCase):
def setUp(self):
vocab = ["air", "plane</w>", "port</w>"]
vocab += ["<|endoftext|>", "<|startoftext|>"]
vocab = dict([(token, i + 1) for i, token in enumerate(vocab)])
merges = ["a i", "p l", "n e</w>", "p o", "r t</w>", "ai r", "pl a"]
merges += ["po rt</w>", "pla ne</w>"]
self.tokenizer = CLIPTokenizer(vocabulary=vocab, merges=merges)
self.init_kwargs = {
"tokenizer": self.tokenizer,
"sequence_length": 8,
}
self.input_data = [" airplane airport"]

def test_preprocessor_basics(self):
self.run_preprocessing_layer_test(
cls=CLIPPreprocessor,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output={
"token_ids": [[5, 1, 2, 1, 3, 4, 4, 4]],
"padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
},
)

def test_no_start_end_token(self):
input_data = [" airplane airport"] * 4
preprocessor = CLIPPreprocessor(
tokenizer=self.tokenizer,
sequence_length=8,
add_start_token=False,
add_end_token=False,
pad_with_end_token=False,
)
x = preprocessor(input_data)
self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 0, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)

def test_sequence_length_override(self):
input_data = " airplane airport"
preprocessor = CLIPPreprocessor(**self.init_kwargs)
x = preprocessor(input_data, sequence_length=4)
self.assertAllEqual(x["token_ids"], [5, 1, 2, 1])

@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
def test_all_presets(self):
self.skipTest("TODO")
for preset in CLIPPreprocessor.presets:
self.run_preset_test(
cls=CLIPPreprocessor,
preset=preset,
input_data=self.input_data,
)