Skip to content

Commit 16359ab

Browse files
committed
Merge branch 'master' into dr-support-pip-cm
2 parents 8f492d8 + 7f374e4 commit 16359ab

File tree

1 file changed

+16 -4 lines changed

1 file changed

+16 -4 lines changed

comfy/ldm/lumina/model.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -522,7 +522,7 @@ def patchify_and_embed(
522522
max_cap_len = max(l_effective_cap_len)
523523
max_img_len = max(l_effective_img_len)
524524

525-
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
525+
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
526526

527527
for i in range(bsz):
528528
cap_len = l_effective_cap_len[i]
@@ -531,10 +531,22 @@ def patchify_and_embed(
531531
H_tokens, W_tokens = H // pH, W // pW
532532
assert H_tokens * W_tokens == img_len
533533

534-
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
534+
rope_options = transformer_options.get("rope_options", None)
535+
h_scale = 1.0
536+
w_scale = 1.0
537+
h_start = 0
538+
w_start = 0
539+
if rope_options is not None:
540+
h_scale = rope_options.get("scale_y", 1.0)
541+
w_scale = rope_options.get("scale_x", 1.0)
542+
543+
h_start = rope_options.get("shift_y", 0.0)
544+
w_start = rope_options.get("shift_x", 0.0)
545+
546+
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
535547
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
536-
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
537-
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
548+
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
549+
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
538550
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
539551
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
540552

0 commit comments

Comments (0)