
Commit 7b9ad52

Merge branch 'master' into dr-support-pip-cm
2 parents: a58c4fb + 7d61033

2 files changed: 13 additions, 6 deletions

README.md (1 addition, 1 deletion)

```diff
@@ -221,7 +221,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
 
 This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
 
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
 
 
 ### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
```
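As a quick aside (not part of the commit), a minimal sketch for checking which wheel actually got installed after running that command; on ROCm builds of PyTorch, `torch.version.hip` is populated, while it is `None` on CUDA-only builds:

```python
# Illustrative post-install check (assumption: run in the same Python
# environment the pip command above installed into).
import torch

print(torch.__version__)          # ROCm nightlies carry a "+rocm..." local version suffix
print(torch.version.hip)          # HIP/ROCm version string on ROCm builds, None otherwise
print(torch.cuda.is_available())  # ROCm devices are exposed through the torch.cuda API
```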

comfy/sd1_clip.py (12 additions, 5 deletions)

```diff
@@ -460,14 +460,15 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
         self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
         self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
         self.end_token = None
         self.min_padding = min_padding
+        self.pad_left = pad_left
 
         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
@@ -522,6 +523,12 @@ def _try_get_embedding(self, embedding_name:str):
             return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
         return (embed, leftover)
 
+    def pad_tokens(self, tokens, amount):
+        if self.pad_left:
+            for i in range(amount):
+                tokens.insert(0, (self.pad_token, 1.0, 0))
+        else:
+            tokens.extend([(self.pad_token, 1.0, 0)] * amount)
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
         '''
@@ -600,7 +607,7 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_optio
                         if self.end_token is not None:
                             batch.append((self.end_token, 1.0, 0))
                         if self.pad_to_max_length:
-                            batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
+                            self.pad_tokens(batch, remaining_length)
                     #start new batch
                     batch = []
                     if self.start_token is not None:
@@ -614,11 +621,11 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_optio
         if self.end_token is not None:
             batch.append((self.end_token, 1.0, 0))
         if min_padding is not None:
-            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+            self.pad_tokens(batch, min_padding)
         if self.pad_to_max_length and len(batch) < self.max_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+            self.pad_tokens(batch, self.max_length - len(batch))
         if min_length is not None and len(batch) < min_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
+            self.pad_tokens(batch, min_length - len(batch))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
```
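For context on the comfy/sd1_clip.py change above, here is a minimal standalone sketch (hypothetical token ids, using the same (token, weight, word_id) tuple layout as the file) of what the new pad_left flag changes: the default path still appends pad tokens after the prompt, while pad_left=True inserts them in front of it.

```python
# Standalone illustration of the padding behavior added in this commit.
# PAD uses a hypothetical pad token id of 0; real values come from the tokenizer.
PAD = (0, 1.0, 0)  # (token, weight, word_id)

def pad_tokens(tokens, amount, pad_left=False):
    if pad_left:
        for _ in range(amount):
            tokens.insert(0, PAD)      # prepend: padding lands before the prompt tokens
    else:
        tokens.extend([PAD] * amount)  # append: the previous (and default) behavior

right = [(320, 1.0, 1), (1125, 1.0, 2)]
left = list(right)
pad_tokens(right, 3)
pad_tokens(left, 3, pad_left=True)
print(right)  # [(320, 1.0, 1), (1125, 1.0, 2), (0, 1.0, 0), (0, 1.0, 0), (0, 1.0, 0)]
print(left)   # [(0, 1.0, 0), (0, 1.0, 0), (0, 1.0, 0), (320, 1.0, 1), (1125, 1.0, 2)]
```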
