Merged
73 changes: 56 additions & 17 deletions src/compel/compel.py
@@ -7,7 +7,7 @@

from . import cross_attention_control
from .conditioning_scheduler import ConditioningScheduler, StaticConditioningScheduler
from .embeddings_provider import EmbeddingsProvider, BaseTextualInversionManager, DownweightMode
from .embeddings_provider import EmbeddingsProvider, BaseTextualInversionManager, DownweightMode, EmbeddingsProviderMulti
from .prompt_parser import Blend, FlattenedPrompt, PromptParser, CrossAttentionControlSubstitute, Conjunction

__all__ = ["Compel", "DownweightMode"]
@@ -21,15 +21,18 @@ class Compel:


def __init__(self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
textual_inversion_manager: Optional[BaseTextualInversionManager] = None,
dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
truncate_long_prompts: bool = True,
padding_attention_mask_value: int = 1,
downweight_mode: DownweightMode = DownweightMode.MASK,
use_penultimate_clip_layer: bool=False,
device: Optional[str] = None):
use_penultimate_clip_layer: Union[bool, List[bool]]=False,
use_penultimate_layer_norm: Union[bool, List[bool]]=False,
requires_pooled: Union[bool, List[bool]]=False,
device: Optional[str] = None
):
"""
Initialize Compel. The tokenizer and text_encoder can be lifted directly from any DiffusionPipeline.

@@ -50,22 +53,42 @@ def __init__(self,
`device`: The torch device on which the tensors should be created. If a device is not specified, the device will
be the same as that of the `text_encoder` at the moment when `build_conditioning_tensor()` is called.
"""
self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=textual_inversion_manager,
dtype_for_device_getter=dtype_for_device_getter,
truncate=truncate_long_prompts,
padding_attention_mask_value = padding_attention_mask_value,
downweight_mode=downweight_mode,
use_penultimate_clip_layer=use_penultimate_clip_layer
)
if isinstance(tokenizer, (tuple, list)) and not isinstance(text_encoder, (tuple, list)):
raise ValueError("Cannot provide list of tokenizers, but not of text encoders.")
elif not isinstance(tokenizer, (tuple, list)) and isinstance(text_encoder, (tuple, list)):
raise ValueError("Cannot provide list of text encoders, but not of tokenizers.")
elif isinstance(tokenizer, (tuple, list)) and isinstance(text_encoder, (tuple, list)):
self.conditioning_provider = EmbeddingsProviderMulti(tokenizers=tokenizer,
text_encoders=text_encoder,
textual_inversion_manager=textual_inversion_manager,
dtype_for_device_getter=dtype_for_device_getter,
truncate=truncate_long_prompts,
padding_attention_mask_value = padding_attention_mask_value,
downweight_mode=downweight_mode,
use_penultimate_clip_layer=use_penultimate_clip_layer,
use_penultimate_layer_norm=use_penultimate_layer_norm,
requires_pooled=requires_pooled,
)
else:
self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=textual_inversion_manager,
dtype_for_device_getter=dtype_for_device_getter,
truncate=truncate_long_prompts,
padding_attention_mask_value = padding_attention_mask_value,
downweight_mode=downweight_mode,
use_penultimate_clip_layer=use_penultimate_clip_layer,
use_penultimate_layer_norm=use_penultimate_layer_norm,
requires_pooled=requires_pooled,
)

self._device = device

@property
def device(self):
return self._device if self._device else self.conditioning_provider.text_encoder.device

    def make_conditioning_scheduler(self, positive_prompt: str, negative_prompt: str='') -> ConditioningScheduler:
"""
Return a ConditioningScheduler object that provides conditioning tensors for different diffusion steps (currently
not fully implemented).
@@ -78,13 +101,19 @@ def make_conditioning_scheduler(self, negative_prompt: str
return StaticConditioningScheduler(positive_conditioning=positive_conditioning,
negative_conditioning=negative_conditioning)

def build_conditioning_tensor(self, text: str) -> torch.Tensor:
def build_conditioning_tensor(self, text: str, return_pooled: bool = False) -> torch.Tensor:
"""
Build a conditioning tensor by parsing the text for Compel syntax, constructing a Conjunction, and then
building a conditioning tensor from that Conjunction.
"""
conjunction = self.parse_prompt_string(text)
conditioning, _ = self.build_conditioning_tensor_for_conjunction(conjunction)

pooled = self.conditioning_provider.maybe_get_pooled([text])

if return_pooled and pooled is not None:
return conditioning, pooled

return conditioning
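
A short sketch of the new return_pooled behaviour (illustration only, not part of the diff): the method returns a plain tensor unless pooled output was both requested and produced, in which case it returns a (conditioning, pooled) tuple.

# Hedged sketch; `compel` is assumed to be a Compel instance whose provider was built with requires_pooled enabled.
result = compel.build_conditioning_tensor("a forest at dawn", return_pooled=True)
if isinstance(result, tuple):
    conditioning, pooled = result   # pooled text_embeds were produced by the provider
else:
    conditioning = result           # provider was not configured for pooled output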

@torch.no_grad()
@@ -100,12 +129,22 @@ def __call__(self, text: Union[str, List[str]]) -> torch.FloatTensor:
text = [text]

cond_tensor = []
pooled = []
for text_input in text:
cond_tensor.append(self.build_conditioning_tensor(text_input))
output = self.build_conditioning_tensor(text_input, return_pooled=True)

requires_pooled = len(output) > 1
cond_tensor.append(output[0] if requires_pooled else output)

if requires_pooled:
pooled.append(output[1])

cond_tensor = self.pad_conditioning_tensors_to_same_length(conditionings=cond_tensor)
cond_tensor = torch.cat(cond_tensor)

if len(pooled) > 0:
return cond_tensor, torch.cat(pooled)

return cond_tensor

@classmethod
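For orientation, a minimal usage sketch of the new multi-encoder path (not part of the diff). It assumes the standard SDXL base checkpoint layout with its two tokenizers and text encoders, and mirrors the option combination exercised by the new test further down; the settings your own pipeline needs may differ.

from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from compel import Compel

# Assumed checkpoint/layout: SDXL base with tokenizer/tokenizer_2 and text_encoder/text_encoder_2 subfolders.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tokenizer_1 = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
text_encoder_1 = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2")

compel = Compel(
    tokenizer=[tokenizer_1, tokenizer_2],
    text_encoder=[text_encoder_1, text_encoder_2],
    use_penultimate_clip_layer=True,     # scalar bools are broadcast to every encoder
    use_penultimate_layer_norm=False,    # keep the raw penultimate hidden state
    requires_pooled=[False, True],       # only the projection encoder contributes a pooled tensor
)

# With at least one requires_pooled=True, __call__ returns a (conditioning, pooled) pair.
conditioning, pooled = compel("a forest++ at dawn")

Because requires_pooled is set for one encoder, the call returns both the concatenated per-token conditioning and the pooled text_embeds; with requires_pooled left at False everywhere, it returns a single tensor exactly as before.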
108 changes: 101 additions & 7 deletions src/compel/embeddings_provider.py
@@ -4,7 +4,8 @@
from typing import Callable, Union, Tuple, List, Optional

import torch
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from typing import List, Tuple

__all__ = ["EmbeddingsProvider", "DownweightMode"]

@@ -22,13 +23,15 @@ class EmbeddingsProvider:

def __init__(self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], # convert a list of int token ids to a tensor of embeddings
textual_inversion_manager: BaseTextualInversionManager = None,
dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
truncate: bool = True,
padding_attention_mask_value: int = 1,
downweight_mode: DownweightMode = DownweightMode.MASK,
use_penultimate_clip_layer: bool=False
use_penultimate_clip_layer: bool=False,
use_penultimate_layer_norm: bool=True,
requires_pooled: bool=False,
):
"""
`tokenizer`: converts strings to lists of int token ids
@@ -50,6 +53,8 @@ def __init__(self,
self.padding_attention_mask_value = padding_attention_mask_value
self.downweight_mode = downweight_mode
self.use_penultimate_clip_layer = use_penultimate_clip_layer
self.use_penultimate_layer_norm = use_penultimate_layer_norm
self.requires_pooled = requires_pooled

# by default always use float32
self.get_dtype_for_device = dtype_for_device_getter
@@ -183,7 +188,7 @@ def get_embeddings_for_weighted_prompt_fragments(self,
else:
return batch_z

def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True) -> List[List[int]]:
def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = True, padding: str = 'do_not_pad') -> List[List[int]]:
"""
Convert a list of strings like `["a cat", "a dog", "monkey riding a bicycle"]` into a list of lists of token
ids like `[[bos, 0, 1, eos], [bos, 0, 2, eos], [bos, 3, 4, 0, 5, eos]]`. bos/eos markers are skipped if
@@ -200,7 +205,7 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool =
token_ids_list = self.tokenizer(
texts,
truncation=self.truncate_to_model_max_length,
padding='do_not_pad',
padding=padding,
return_tensors=None, # just give me lists of ints
)['input_ids']

@@ -220,6 +225,18 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool =

return result

def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]:
if not self.requires_pooled:
return None

token_ids = self.get_token_ids(texts, padding="max_length")
token_ids = torch.tensor(token_ids, dtype=torch.long).to(self.text_encoder.device)

text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True)
pooled = text_encoder_output.text_embeds

return pooled

def get_token_ids_and_expand_weights(self, fragments: List[str], weights: List[float], device: str
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
'''
@@ -305,7 +322,7 @@ def build_weighted_embedding_tensor(self,
device: Optional[str] = None) -> torch.Tensor:
"""
Build a tensor that embeds the passed-in token IDs and applies the given per_token weights

:param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary
integer (i.e. n==1 for shorter prompts, or it may be >1 if there are more than max_length tokens in the
original prompt)
@@ -366,7 +383,10 @@ def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor,
if self.use_penultimate_clip_layer:
# needs normalizing
penultimate_hidden_state = text_encoder_output.hidden_states[-2]
return self.text_encoder.text_model.final_layer_norm(penultimate_hidden_state)

if self.use_penultimate_layer_norm:
penultimate_hidden_state = self.text_encoder.text_model.final_layer_norm(penultimate_hidden_state)
return penultimate_hidden_state
else:
# already normalized
return text_encoder_output.last_hidden_state
@@ -433,3 +453,77 @@ def _get_token_ranges_for_fragments(self, chunked_and_padded_token_ids: List[int
fragment_start = fragment_end + 1

return corresponding_indices


class EmbeddingsProviderMulti:

def __init__(self,
                 tokenizers: List[CLIPTokenizer],
                 text_encoders: List[Union[CLIPTextModel, CLIPTextModelWithProjection]],  # convert lists of int token ids to tensors of embeddings
textual_inversion_manager: BaseTextualInversionManager = None,
dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
truncate: bool = True,
padding_attention_mask_value: int = 1,
downweight_mode: DownweightMode = DownweightMode.MASK,
use_penultimate_clip_layer: Union[List[bool], bool]=False,
use_penultimate_layer_norm: Union[List[bool], bool]=True,
requires_pooled: Union[List[bool], bool]=False,
):

use_penultimate_clip_layer = len(text_encoders) * [use_penultimate_clip_layer] if not isinstance(use_penultimate_clip_layer, (list, tuple)) else use_penultimate_clip_layer
use_penultimate_layer_norm = len(text_encoders) * [use_penultimate_layer_norm] if not isinstance(use_penultimate_layer_norm, (list, tuple)) else use_penultimate_layer_norm
requires_pooled = len(text_encoders) * [requires_pooled] if not isinstance(requires_pooled, (list, tuple)) else requires_pooled

self.embedding_providers = [
EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, truncate, padding_attention_mask_value, downweight_mode, clip_layer, clip_norm, pooled)
for tokenizer, text_encoder, clip_layer, clip_norm, pooled in zip(tokenizers, text_encoders, use_penultimate_clip_layer, use_penultimate_layer_norm, requires_pooled)
]

@property
def text_encoder(self):
return self.embedding_providers[0].text_encoder

@property
def tokenizer(self):
return self.embedding_providers[0].tokenizer

def get_token_ids(self, *args, **kwargs):
        # get_token_ids does not use padding; the padding token ID is the only ID that can differ
        # between tokenizers, so for simplicity we just delegate to the first tokenizer's provider.
        return self.embedding_providers[0].get_token_ids(*args, **kwargs)

def maybe_get_pooled(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None) -> Optional[torch.Tensor]:
pooled = [provider.maybe_get_pooled(texts, attention_mask) for provider in self.embedding_providers]
pooled = [p for p in pooled if p is not None]

if len(pooled) == 0:
return None

return torch.cat(pooled, dim=-1)

def get_embeddings_for_weighted_prompt_fragments(self,
text_batch: List[List[str]],
fragment_weights_batch: List[List[float]],
should_return_tokens: bool = False,
device='cpu',
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

outputs = [provider.get_embeddings_for_weighted_prompt_fragments(text_batch, fragment_weights_batch, should_return_tokens=should_return_tokens, device=device) for provider in self.embedding_providers]

text_embeddings_list = []
tokens = []

for output in outputs:
text_embeddings_list.append(output[0])

if should_return_tokens:
tokens.append(output[1])

text_embeddings = torch.cat(text_embeddings_list, dim=-1)

outputs = (text_embeddings,)

if should_return_tokens:
outputs += (tokens,)

return outputs
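
To make the concatenation behaviour of EmbeddingsProviderMulti concrete, a small shape sketch (illustrative numbers only; an SDXL-like pairing of 768- and 1280-dimensional encoders is assumed):

import torch

# Per-token embeddings from each provider are joined along the feature axis...
per_token_1 = torch.randn(1, 77, 768)     # (batch, tokens, hidden) from the first encoder
per_token_2 = torch.randn(1, 77, 1280)    # from the second encoder
conditioning = torch.cat([per_token_1, per_token_2], dim=-1)
assert conditioning.shape == (1, 77, 768 + 1280)

# ...and pooled outputs are gathered only from providers built with requires_pooled=True,
# then concatenated the same way.
pooled_2 = torch.randn(1, 1280)           # text_embeds from the projection encoder
pooled = torch.cat([pooled_2], dim=-1)
assert pooled.shape == (1, 1280)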
16 changes: 12 additions & 4 deletions test/prompting_test_utils.py
@@ -47,7 +47,7 @@ def resize_token_embeddings(self, new_size=None):
def get_input_embeddings(self):
return self.embeddings

def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor], return_dict: bool=True):
def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor], output_hidden_states: bool=False, return_dict: bool=True):
if input_ids.shape[0] > 1:
raise AssertionError("for unit testing, only batch size =1 is supported")
all_embeddings = torch.cat([e.unsqueeze(0) for e in self.embeddings]).to(self.device)
@@ -71,11 +71,15 @@ def __getitem__(self, item):
def hidden_states(self):
return [-self.last_hidden_state, self.last_hidden_state]

@property
def text_embeds(self):
return self.last_hidden_state[:, -1, :]

o = EmbeddingsObject(embeddings)
return o

def __call__(self, input_ids, attention_mask=None, **kwargs):
return self.forward(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
return self.forward(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, output_hidden_states=kwargs.pop("output_hidden_states", False))

@property
def text_model(self):
@@ -104,8 +108,12 @@ def __call__(self, fragments, **kwargs):
else x
for x in tokenized]
padding_strategy = kwargs.get('padding', 'do_not_pad')
if padding_strategy != 'do_not_pad':
raise Exception(f"for unit tests only 'do_not_pad' is supported as a padding strategy (got '{padding_strategy}')")
if padding_strategy not in ['do_not_pad', 'max_length']:
raise Exception(f"for unit tests only 'do_not_pad' and 'max_length' is supported as a padding strategy (got '{padding_strategy}')")

if padding_strategy == "max_length":
            tokenized = [(tokens[:-1] + (self.model_max_length - len(tokens)) * [self.pad_token_id] + tokens[-1:]) for tokens in tokenized]

return {'input_ids': tokenized}

def convert_tokens_to_ids(self, token_str):
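For clarity, the intent of the dummy tokenizer's new max_length branch is to insert pad tokens just before the trailing EOS marker so every sequence reaches model_max_length; a standalone sketch of that construction (illustration only, names are hypothetical):

# Pads a [bos, t1, ..., tn, eos] token list out to max_length by inserting pad ids before eos.
def pad_to_max_length(tokens: list, max_length: int, pad_token_id: int) -> list:
    return tokens[:-1] + (max_length - len(tokens)) * [pad_token_id] + tokens[-1:]

assert len(pad_to_max_length([49406, 320, 2368, 49407], 8, 49407)) == 8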
16 changes: 16 additions & 0 deletions test/test_compel.py
@@ -73,6 +73,22 @@ def test_basic_prompt(self):
conditioning,
atol=1e-6))

def test_basic_prompt_multi_text_encoder(self):
tokenizer_1 = DummyTokenizer()
text_encoder_1 = DummyTransformer()

tokenizer_2 = DummyTokenizer()
text_encoder_2 = DummyTransformer()

compel = Compel(tokenizer=[tokenizer_1, tokenizer_2], text_encoder=[text_encoder_1, text_encoder_2], use_penultimate_clip_layer=True, use_penultimate_layer_norm=False, requires_pooled=[False, True])

# test "a b c" makes it to the Conditioning intact for t=0, t=0.5, t=1
prompt = " ".join(KNOWN_WORDS[:3])
output, pooled = compel(prompt)

assert output.shape == (1, 77, 2 * 768)
assert pooled.shape == (1, 768)


def test_basic_negative_prompt(self):
tokenizer = DummyTokenizer()