Commit ab15819: Fix naming

1 parent b1d54ce, commit ab15819

3 files changed: +9 -9 lines

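The rename separates two different knobs: the constructor-level flag return_pooled becomes requires_pooled (in Compel and in EmbeddingsProviderMulti), while the per-call parameter of EmbeddingsProvider.build_weighted_embedding_tensor becomes should_return_pooled.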

src/compel/compel.py

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@ def __init__(self,
                  downweight_mode: DownweightMode = DownweightMode.MASK,
                  use_penultimate_clip_layer: bool=False,
-                 return_pooled: Union[str, List[bool]] = False,
+                 requires_pooled: Union[str, List[bool]] = False,
                  device: Optional[str] = None):
         """
         Initialize Compel. The tokenizer and text_encoder can be lifted directly from any DiffusionPipeline.
@@ -64,7 +64,7 @@ def __init__(self,
                                                             padding_attention_mask_value = padding_attention_mask_value,
                                                             downweight_mode=downweight_mode,
                                                             use_penultimate_clip_layer=use_penultimate_clip_layer,
-                                                            return_pooled=return_pooled,
+                                                            requires_pooled=requires_pooled,
                                                             )
         else:
             self.conditioning_provider = EmbeddingsProvider(tokenizer=tokenizer,
@@ -75,7 +75,7 @@ def __init__(self,
                                                             padding_attention_mask_value = padding_attention_mask_value,
                                                             downweight_mode=downweight_mode,
                                                             use_penultimate_clip_layer=use_penultimate_clip_layer,
-                                                            return_pooled=return_pooled,
+                                                            requires_pooled=requires_pooled,
                                                             )
         self._device = device
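A usage note, not part of the commit: a minimal sketch of the renamed keyword at a realistic call site, assuming a diffusers SDXL-style pipeline with two text encoders. The keyword spellings (hidden_states_type, requires_pooled) are taken from this commit's test file; the pipeline class and model id are assumptions.

    # A minimal sketch, assuming diffusers' StableDiffusionXLPipeline; only the
    # requires_pooled keyword (renamed in this commit) comes from the diff itself.
    from compel import Compel
    from diffusers import StableDiffusionXLPipeline

    pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    compel = Compel(tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
                    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
                    hidden_states_type="penultimate",
                    requires_pooled=[False, True])  # pooled embedding from the second encoder only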

src/compel/embeddings_provider.py

Lines changed: 5 additions & 5 deletions
@@ -312,7 +312,7 @@ def build_weighted_embedding_tensor(self,
                                         token_ids: torch.Tensor,
                                         per_token_weights: torch.Tensor,
                                         attention_mask: Optional[torch.Tensor] = None,
-                                        return_pooled: bool = False,
+                                        should_return_pooled: bool = False,
                                         device: Optional[str] = None) -> torch.Tensor:
         """
         :param token_ids: A tensor of shape `n*[self.max_length]` containing token IDs (ints) where n is some arbitrary
@@ -373,7 +373,7 @@ def build_weighted_embedding_tensor(self,

             chunk_start_index += chunk_size

-        if self.requires_pooled:
+        if should_return_pooled:
             return weighted_z, pooled

         return weighted_z
@@ -473,15 +473,15 @@ def __init__(self,
                  textual_inversion_manager: BaseTextualInversionManager = None,
                  dtype_for_device_getter: Callable[[torch.device], torch.dtype] = lambda device: torch.float32,
                  hidden_states_types: Union[str, List[str]] = "final",
-                 return_pooled: Union[str, List[bool]] = False,
+                 requires_pooled: Union[str, List[bool]] = False,
                  ):

         hidden_states_types = len(text_encoders) * [hidden_states_types] if not isinstance(hidden_states_types, (list, tuple)) else hidden_states_types
-        return_pooled = len(text_encoders) * [return_pooled] if not isinstance(return_pooled, (list, tuple)) else return_pooled
+        requires_pooled = len(text_encoders) * [requires_pooled] if not isinstance(requires_pooled, (list, tuple)) else requires_pooled

         self.embedding_providers = [
             EmbeddingsProvider(tokenizer, text_encoder, textual_inversion_manager, dtype_for_device_getter, hidden_states_type, pooled)
-            for tokenizer, text_encoder, hidden_states_type, pooled in zip(tokenizers, text_encoders, hidden_states_types, return_pooled)
+            for tokenizer, text_encoder, hidden_states_type, pooled in zip(tokenizers, text_encoders, hidden_states_types, requires_pooled)
         ]

     @property
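Two details here go beyond a mechanical rename. In the second hunk, build_weighted_embedding_tensor now honors its own argument: the old body tested the instance-level self.requires_pooled rather than the return_pooled parameter the caller passed in. In the third hunk, EmbeddingsProviderMulti broadcasts a scalar flag to one entry per text encoder before zipping. A standalone sketch of that broadcast idiom (broadcast_per_encoder is a hypothetical name, not part of compel):

    from typing import List, Union

    def broadcast_per_encoder(flag: Union[bool, List[bool]], num_encoders: int) -> List[bool]:
        # A bare bool applies to every encoder; a list/tuple is assumed to be per-encoder already.
        return num_encoders * [flag] if not isinstance(flag, (list, tuple)) else flag

    assert broadcast_per_encoder(False, 2) == [False, False]         # scalar expands
    assert broadcast_per_encoder([False, True], 2) == [False, True]  # list passes through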

test/test_compel.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def test_basic_prompt_multi_text_encoder(self):
         tokenizer_2 = DummyTokenizer()
         text_encoder_2 = DummyTransformer()

-        compel = Compel(tokenizer=[tokenizer_1, tokenizer_2], text_encoder=[text_encoder_1, text_encoder_2], hidden_states_type="penultimate", return_pooled=[False, True])
+        compel = Compel(tokenizer=[tokenizer_1, tokenizer_2], text_encoder=[text_encoder_1, text_encoder_2], hidden_states_type="penultimate", requires_pooled=[False, True])

         # test "a b c" makes it to the Conditioning intact for t=0, t=0.5, t=1
         prompt = " ".join(KNOWN_WORDS[:3])