Skip to content
Prev Previous commit
Next Next commit
Refactor control tensor handling and update main channels calculation…
… in QwenImageEditPlus2509Model
  • Loading branch information
Julien Blanchon committed Sep 24, 2025
commit 7eb7513c0a1b0f0fc4fae29924de16c27b03cf32
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,15 @@ def condition_noisy_latents(
control_tensors = []

# Primary control tensor
control_tensor = batch.control_tensor
if control_tensor is not None:
control_tensors.append(control_tensor)
if batch.control_tensor is not None:
control_tensors.append(batch.control_tensor)

# Secondary control tensor
if hasattr(batch, "control_tensor2") and batch.control_tensor2 is not None:
if batch.control_tensor2 is not None:
control_tensors.append(batch.control_tensor2)

# Third control tensor
if hasattr(batch, "control_tensor3") and batch.control_tensor3 is not None:
if batch.control_tensor3 is not None:
control_tensors.append(batch.control_tensor3)

if not control_tensors:
Expand All @@ -189,6 +188,7 @@ def condition_noisy_latents(
target_h = batch.file_items[0].crop_height
target_w = batch.file_items[0].crop_width

# Ensure control image matches the main image dimensions so VAE latents align
if (
control_tensor.shape[2] != target_h
or control_tensor.shape[3] != target_w
Expand Down Expand Up @@ -271,7 +271,7 @@ def get_noise_prediction(
# The latents are packed as [main_latents, control1_latents, control2_latents, ...]

# Split the latents - first part is main latents, rest are control latents
main_channels = 16 # Qwen Image uses 16 channels for main latents
main_channels = self.transformer.config.in_channels // 4
main_latents = latent_model_input[:, :main_channels, :, :]

# The rest are control latents - we need to split them by number of control images
Expand Down Expand Up @@ -333,9 +333,9 @@ def get_noise_prediction(
)
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype)
prompt_embeds_mask = text_embeddings.attention_mask.to(
self.device_torch, dtype=torch.int64
)
# prompt_embeds_mask = text_embeddings.attention_mask.to(
# self.device_torch, dtype=torch.int64
# )

noise_pred = self.transformer(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
Expand Down