49 changes: 34 additions & 15 deletions src/compel/compel.py
@@ -7,7 +7,7 @@

from . import cross_attention_control
from .conditioning_scheduler import ConditioningScheduler, StaticConditioningScheduler
from .embeddings_provider import EmbeddingsProvider, BaseTextualInversionManager, DownweightMode
from .embeddings_provider import EmbeddingsProvider, BaseTextualInversionManager, DownweightMode, EmbeddingsProviderMulti
from .prompt_parser import Blend, FlattenedPrompt, PromptParser, CrossAttentionControlSubstitute, Conjunction

__all__ = ["Compel", "DownweightMode"]
@@ -21,15 +21,17 @@ class Compel:


def __init__(self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
textual_inversion_manager: Optional[BaseTextualInversionManager] = None,
dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
truncate_long_prompts: bool = True,
padding_attention_mask_value: int = 1,
downweight_mode: DownweightMode = DownweightMode.MASK,
use_penultimate_clip_layer: bool=False,
device: Optional[str] = None):
use_penultimate_clip_layer: Union[bool, List[bool]]=False,
use_penultimate_layer_norm: Union[bool, List[bool]]=False,
device: Optional[str] = None
):
"""
Initialize Compel. The tokenizer and text_encoder can be lifted directly from any DiffusionPipeline.

@@ -50,16 +52,33 @@ def __init__(self,
`device`: The torch device on which the tensors should be created. If a device is not specified, the device will
be the same as that of the `text_encoder` at the moment when `build_conditioning_tensor()` is called.
"""
self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=textual_inversion_manager,
dtype_for_device_getter=dtype_for_device_getter,
truncate=truncate_long_prompts,
padding_attention_mask_value = padding_attention_mask_value,
downweight_mode=downweight_mode,
use_penultimate_clip_layer=use_penultimate_clip_layer
)
self._device = device
if isinstance(tokenizer, (tuple, list)) and not isinstance(text_encoder, (tuple, list)):
raise ValueError("A list of tokenizers was provided, but not a list of text encoders.")
elif not isinstance(tokenizer, (tuple, list)) and isinstance(text_encoder, (tuple, list)):
raise ValueError("A list of text encoders was provided, but not a list of tokenizers.")
elif isinstance(tokenizer, (tuple, list)) and isinstance(text_encoder, (tuple, list)):
self.conditioning_provider = EmbeddingsProviderMulti(tokenizers=tokenizer,
text_encoders=text_encoder,
textual_inversion_manager=textual_inversion_manager,
dtype_for_device_getter=dtype_for_device_getter,
truncate=truncate_long_prompts,
padding_attention_mask_value = padding_attention_mask_value,
downweight_mode=downweight_mode,
use_penultimate_clip_layer=use_penultimate_clip_layer,
use_penultimate_layer_norm=use_penultimate_layer_norm,
)
else:
self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=textual_inversion_manager,
dtype_for_device_getter=dtype_for_device_getter,
truncate=truncate_long_prompts,
padding_attention_mask_value = padding_attention_mask_value,
downweight_mode=downweight_mode,
use_penultimate_clip_layer=use_penultimate_clip_layer,
use_penultimate_layer_norm=use_penultimate_layer_norm,
)
self._device = device

@property
def device(self):
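For orientation (not part of the diff): a minimal usage sketch of the new multi-encoder path. It assumes a diffusers SDXL-style pipeline exposing tokenizer/tokenizer_2 and text_encoder/text_encoder_2; the model ID, prompt, and flag values are illustrative and not taken from this PR.

from compel import Compel
from diffusers import StableDiffusionXLPipeline

# assumption: an SDXL-style pipeline with two tokenizers and two text encoders
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# passing lists routes construction to EmbeddingsProviderMulti instead of EmbeddingsProvider
compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
                text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
                use_penultimate_clip_layer=[True, True],        # per-encoder flags (illustrative values)
                use_penultimate_layer_norm=[False, False])

conditioning = compel.build_conditioning_tensor("a forest++ at dawn")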
104 changes: 97 additions & 7 deletions src/compel/embeddings_provider.py
@@ -4,7 +4,8 @@
from typing import Callable, Union, Tuple, List, Optional

import torch
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from typing import List, Tuple

__all__ = ["EmbeddingsProvider", "DownweightMode"]

@@ -22,13 +23,14 @@ class EmbeddingsProvider:

def __init__(self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], # convert a list of int token ids to a tensor of embeddings
textual_inversion_manager: BaseTextualInversionManager = None,
dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
truncate: bool = True,
padding_attention_mask_value: int = 1,
downweight_mode: DownweightMode = DownweightMode.MASK,
use_penultimate_clip_layer: bool=False
use_penultimate_clip_layer: bool=False,
use_penultimate_layer_norm: bool=True,
):
"""
`tokenizer`: converts strings to lists of int token ids
@@ -50,6 +52,7 @@ def __init__(self,
self.padding_attention_mask_value = padding_attention_mask_value
self.downweight_mode = downweight_mode
self.use_penultimate_clip_layer = use_penultimate_clip_layer
self.use_penultimate_layer_norm = use_penultimate_layer_norm

# by default always use float32
self.get_dtype_for_device = dtype_for_device_getter
@@ -183,7 +186,7 @@ def get_embeddings_for_weighted_prompt_fragments(self,
else:
return batch_z

def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]:
def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True, padding: str = 'do_not_pad') -> List[List[int]]:
"""
Convert a list of strings like `["a cat", "a dog", "monkey riding a bicycle"]` into a list of lists of token
ids like `[[bos, 0, 1, eos], [bos, 0, 2, eos], [bos, 3, 4, 0, 5, eos]]`. bos/eos markers are skipped if
@@ -200,7 +203,7 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool =
token_ids_list = self.tokenizer(
texts,
truncation=self.truncate_to_model_max_length,
padding='do_not_pad',
padding=padding,
return_tensors=None, # just give me lists of ints
)['input_ids']

@@ -220,6 +223,16 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool =

return result

def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]:
if not self.requires_pooled:
return None

token_ids = self.get_token_ids(texts, padding="max_length")
# get_token_ids returns plain python lists of ints; the text encoder expects a long tensor
token_ids = torch.tensor(token_ids, dtype=torch.long, device=self.text_encoder.device)
text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True)

pooled = text_encoder_output.pooler_output if "pooler_output" in text_encoder_output else text_encoder_output.text_embeds
return pooled

def get_token_ids_and_expand_weights(self, fragments: List[str], weights: List[float], device: str
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
'''
@@ -305,7 +318,7 @@ def build_weighted_embedding_tensor(self,
device: Optional[str] = None) -> torch.Tensor:
"""
Build a tensor that embeds the passed-in token IDs and applies the given per_token weights

:param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary
integer (i.e. n==1 for shorter prompts, or it may be >1 if there are more than max_length tokens in the
original prompt)
@@ -353,6 +366,7 @@ def build_weighted_embedding_tensor(self,
if weighted_z is None
else torch.cat([weighted_z, this_weighted_z], dim=1)
)

chunk_start_index += chunk_size

return weighted_z
@@ -363,10 +377,14 @@ def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor,
attention_mask,
output_hidden_states=self.use_penultimate_clip_layer,
return_dict=True)

if self.use_penultimate_clip_layer:
# needs normalizing
penultimate_hidden_state = text_encoder_output.hidden_states[-2]
return self.text_encoder.text_model.final_layer_norm(penultimate_hidden_state)

if self.use_penultimate_layer_norm:
penultimate_hidden_state = self.text_encoder.text_model.final_layer_norm(penultimate_hidden_state)
return penultimate_hidden_state
else:
# already normalized
return text_encoder_output.last_hidden_state
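For reference (a sketch, not code from this PR): the combinations of the two flags map onto a stock Hugging Face CLIPTextModel as follows; the model ID and prompt are placeholders.

from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

token_ids = tokenizer(["a cat"], padding="max_length", return_tensors="pt").input_ids
out = text_encoder(token_ids, output_hidden_states=True, return_dict=True)

last = out.last_hidden_state        # use_penultimate_clip_layer=False
penultimate = out.hidden_states[-2] # use_penultimate_clip_layer=True, use_penultimate_layer_norm=False
penultimate_normed = text_encoder.text_model.final_layer_norm(penultimate)  # both flags True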
@@ -433,3 +451,75 @@ def _get_token_ranges_for_fragments(self, chunked_and_padded_token_ids: List[int
fragment_start = fragment_end + 1

return corresponding_indices


class EmbeddingsProviderMulti:

def __init__(self,
tokenizers: List[CLIPTokenizer],
text_encoders: List[Union[CLIPTextModel, CLIPTextModelWithProjection]], # convert lists of int token ids to tensors of embeddings
textual_inversion_manager: BaseTextualInversionManager = None,
dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
truncate: bool = True,
padding_attention_mask_value: int = 1,
downweight_mode: DownweightMode = DownweightMode.MASK,
use_penultimate_clip_layer: Union[bool, List[bool]]=False,
use_penultimate_layer_norm: Union[bool, List[bool]]=True,
):

use_penultimate_clip_layer = len(text_encoders) * [use_penultimate_clip_layer] if not isinstance(use_penultimate_clip_layer, (list, tuple)) else use_penultimate_clip_layer
use_penultimate_layer_norm = len(text_encoders) * [use_penultimate_layer_norm] if not isinstance(use_penultimate_layer_norm, (list, tuple)) else use_penultimate_layer_norm

self.embedding_providers = [
EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, truncate, padding_attention_mask_value, downweight_mode, clip_layer, clip_norm)
for tokenizer, text_encoder, clip_layer, clip_norm in zip(tokenizers, text_encoders, use_penultimate_clip_layer, use_penultimate_layer_norm)
]

@property
def text_encoder(self):
return self.embedding_providers[0].text_encoder

@property
def tokenizer(self):
return self.embedding_providers[0].tokenizer

def get_token_ids(self, *args, **kwargs):
# get_token_ids does not apply padding here; the padding token ID is the only ID that can differ
# between tokenizers, so for simplicity we just delegate to the first tokenizer
return self.embedding_providers[0].get_token_ids(*args, **kwargs)

def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]:
pooled = [provider.maybe_get_pooled(texts, attention_mask) for provider in self.embedding_providers]
pooled = [p for p in pooled if p is not None]

if len(pooled) == 0:
return None

return torch.cat(pooled, dim=-1)

def get_embeddings_for_weighted_prompt_fragments(self,
text_batch: List[List[str]],
fragment_weights_batch: List[List[float]],
should_return_tokens: bool = False,
device='cpu',
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

outputs = [provider.get_embeddings_for_weighted_prompt_fragments(text_batch, fragment_weights_batch, should_return_tokens=should_return_tokens, device=device) for provider in self.embedding_providers]

text_embeddings_list = []
tokens = []

for output in outputs:
text_embeddings_list.append(output[0])

if should_return_tokens:
tokens.append(output[1])

text_embeddings = torch.cat(text_embeddings_list, dim=-1)

outputs = (text_embeddings,)

if should_return_tokens:
outputs += (tokens,)

return outputs
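For orientation (not part of the diff): EmbeddingsProviderMulti runs each wrapped EmbeddingsProvider and concatenates the results along the feature dimension. A minimal shape sketch, with the 768/1280 widths chosen purely as illustrative stand-ins:

import torch

z1 = torch.randn(1, 77, 768)     # stand-in for the first encoder's per-token embeddings
z2 = torch.randn(1, 77, 1280)    # stand-in for a second, wider encoder
combined = torch.cat([z1, z2], dim=-1)
print(combined.shape)            # torch.Size([1, 77, 2048])

# pooled outputs, where an encoder produces one, are concatenated the same way,
# after filtering out the None entries from encoders without a pooled vector
pooled = [None, torch.randn(1, 1280)]
pooled = torch.cat([p for p in pooled if p is not None], dim=-1)
print(pooled.shape)              # torch.Size([1, 1280])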
13 changes: 8 additions & 5 deletions test/prompting_test_utils.py
@@ -47,7 +47,7 @@ def resize_token_embeddings(self, new_size=None):
def get_input_embeddings(self):
return self.embeddings

def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor], return_dict: bool=True):
def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor], output_hidden_states: bool=False, return_dict: bool=True):
if input_ids.shape[0] > 1:
raise AssertionError("for unit testing, only batch size =1 is supported")
all_embeddings = torch.cat([e.unsqueeze(0) for e in self.embeddings]).to(self.device)
@@ -64,8 +64,12 @@ def __init__(self, last_hidden_state):
self.last_hidden_state = last_hidden_state

def __getitem__(self, item):
assert item == 0
return self.last_hidden_state
if item == 0:
return self.last_hidden_state[:, -1, :]
if item == 1:
return self.last_hidden_state
if item == 2:
return 2 * [self.last_hidden_state]

@property
def hidden_states(self):
@@ -75,15 +79,14 @@ def hidden_states(self):
return o

def __call__(self, input_ids, attention_mask=None, **kwargs):
return self.forward(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
return self.forward(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, output_hidden_states=kwargs.pop("output_hidden_states", False))

@property
def text_model(self):
tm = Mock()
tm.final_layer_norm = nn.LayerNorm(normalized_shape=[self.text_model_max_length, self.embedding_length])
return tm


class DummyTokenizer():
def __init__(self, model_max_length=77):
self.tokens = KNOWN_WORDS.copy() + ["<|bos|>", "<|pad|>", "<|eos|>"]
15 changes: 15 additions & 0 deletions test/test_compel.py
@@ -73,6 +73,21 @@ def test_basic_prompt(self):
conditioning,
atol=1e-6))

def test_basic_prompt_multi_text_encoder(self):
tokenizer_1 = DummyTokenizer()
text_encoder_1 = DummyTransformer()

tokenizer_2 = DummyTokenizer()
text_encoder_2 = DummyTransformer()

compel = Compel(tokenizer=[tokenizer_1, tokenizer_2], text_encoder=[text_encoder_1, text_encoder_2], hidden_states_type="penultimate", requires_pooled=[False, True])

# test "a b c" makes it to the Conditioning intact for t=0, t=0.5, t=1
prompt = " ".join(KNOWN_WORDS[:3])
output = compel(prompt)

assert output.shape == (1, 77, 2 * 768)


def test_basic_negative_prompt(self):
tokenizer = DummyTokenizer()