diff --git a/config/examples/temp_examples_qwen_edit_plus/qwen_image_edit_plus_2controls.yaml b/config/examples/temp_examples_qwen_edit_plus/qwen_image_edit_plus_2controls.yaml
new file mode 100644
index 000000000..07d449d36
--- /dev/null
+++ b/config/examples/temp_examples_qwen_edit_plus/qwen_image_edit_plus_2controls.yaml
@@ -0,0 +1,93 @@
+---
+job: "extension"
+config:
+  name: "my_first_qwen_image_edit_plus2509_lora_v3"
+  process:
+    - type: "sd_trainer"
+      training_folder: "/home/ubuntu/ai-toolkit/output"
+      device: "cuda:0"
+      network:
+        type: "lora"
+        linear: 4
+        linear_alpha: 4
+      save:
+        dtype: "float16"
+        save_every: 500
+        max_step_saves_to_keep: 12
+      datasets:
+        - folder_path: "/home/ubuntu/ai-toolkit/datasets/end"
+          control_path: "/home/ubuntu/ai-toolkit/datasets/start"
+          control_path2: "/home/ubuntu/ai-toolkit/datasets/reference"
+          full_size_control_images: true
+          caption_ext: "txt"
+          caption_dropout_rate: 0.05
+          buckets: true
+          cache_latents_to_disk: true
+          keep_native_size: true
+          random_scale: false
+          random_crop: false
+          square_crop: false
+          scale: 0.1
+      train:
+        batch_size: 1
+        unload_text_encoder: false
+        cache_text_embeddings: true
+        steps: 3000
+        gradient_accumulation: 1
+        timestep_type: "weighted"
+        train_unet: true
+        train_text_encoder: false
+        gradient_checkpointing: false
+        noise_scheduler: "flowmatch"
+        optimizer: "adamw8bit"
+        lr: 0.0001
+        dtype: "bf16"
+      model:
+        name_or_path: "Qwen/Qwen-Image-Edit-2509"
+        arch: "qwen_image_edit_plus2509"
+        quantize: true
+        qtype: "qfloat8"
+        quantize_te: true
+        qtype_te: "qfloat8"
+        low_vram: true
+      sample:
+        sampler: "flowmatch"
+        sample_every: 500
+        width: 1024
+        height: 1024
+        samples:
+          - prompt: "Swap the furniture inside the red bounding box of the Image 1 by the reference furniture of the Image 2"
+            ctrl_img: "/home/ubuntu/ai-toolkit/datasets/start_full/94a31bd5-b0ce-42bc-a12e-a76a676c1314_2_to_1.png"
+            ctrl_img2: "/home/ubuntu/ai-toolkit/datasets/reference_full/94a31bd5-b0ce-42bc-a12e-a76a676c1314_2_to_1.png"
+            width: 1024
+            height: 535
+          # - prompt: "Swap the furniture inside the red bounding box of the Image 1 by the reference furniture of the Image 2"
+          #   width: 796
+          #   height: 1024
+          #   ctrl_img: "/home/ubuntu/ai-toolkit/datasets/start_full/95dabfac-acf1-43cb-a85a-aa795bc52739_1_to_2.png"
+          #   ctrl_img2: "/home/ubuntu/ai-toolkit/datasets/reference_full/95dabfac-acf1-43cb-a85a-aa795bc52739_1_to_2.png"
+          # - prompt: "Swap the furniture inside the red bounding box of the Image 1 by the reference furniture of the Image 2"
+          #   width: 961
+          #   height: 1024
+          #   ctrl_img: "/home/ubuntu/ai-toolkit/datasets/start_full/a3912256-a57f-4254-89df-fd06cdfe1c3c_2_to_1.png"
+          #   ctrl_img2: "/home/ubuntu/ai-toolkit/datasets/reference_full/a3912256-a57f-4254-89df-fd06cdfe1c3c_2_to_1.png"
+          # - prompt: "Swap the furniture inside the red bounding box of the Image 1 by the reference furniture of the Image 2"
+          #   width: 1024
+          #   height: 1024
+          #   ctrl_img: "/home/ubuntu/ai-toolkit/datasets/start_full/a71024fc-89bc-4f57-a1da-50a4b0c4ef3b_2_to_1.png"
+          #   ctrl_img2: "/home/ubuntu/ai-toolkit/datasets/reference_full/a71024fc-89bc-4f57-a1da-50a4b0c4ef3b_2_to_1.png"
+          # - prompt: "Swap the furniture inside the red bounding box of the Image 1 by the reference furniture of the Image 2"
+          #   width: 1024
+          #   height: 489
+          #   ctrl_img: "/home/ubuntu/ai-toolkit/datasets/start_full/a8ddc0e8-be33-4717-9bec-1a99b35593c3_1_to_2.png"
+          #   ctrl_img2: "/home/ubuntu/ai-toolkit/datasets/reference_full/a8ddc0e8-be33-4717-9bec-1a99b35593c3_1_to_2.png"
+        neg: ""
+        seed: 42
+        walk_seed: false
+        guidance_scale: 3
+        sample_steps: 25
+      sqlite_db_path: "./aitk_db.db"
+      performance_log_every: 10
+meta:
+  name: "[name]"
+  version: "1.0"
diff --git a/config/examples/train_lora_omnigen2_24gb.yaml b/config/examples/train_lora_omnigen2_24gb.yaml
index 6eb15302d..acfb15f33 100644
--- a/config/examples/train_lora_omnigen2_24gb.yaml
+++ b/config/examples/train_lora_omnigen2_24gb.yaml
@@ -61,7 +61,7 @@ config:
         # will probably need this if gpu supports it for omnigen2, other dtypes may not work correctly
         dtype: bf16
       model:
-        name_or_path: "OmniGen2/OmniGen2
+        name_or_path: "OmniGen2/OmniGen2"
         arch: "omnigen2"
         quantize_te: true # quantize_only te
         # quantize: true # quantize transformer
diff --git a/config/examples/train_lora_qwen_image_edit_32gb.yaml b/config/examples/train_lora_qwen_image_edit_32gb.yaml
index 81a999d79..959afcd21 100644
--- a/config/examples/train_lora_qwen_image_edit_32gb.yaml
+++ b/config/examples/train_lora_qwen_image_edit_32gb.yaml
@@ -58,12 +58,11 @@ config:
         name_or_path: "Qwen/Qwen-Image-Edit"
         arch: "qwen_image_edit"
         quantize: true
-        # qtype_te: "qfloat8" Default float8 qquantization
+        # qtype_te: "qfloat8" # Default float8 quantization
         # to use the ARA use the | pipe to point to hf path, or a local path if you have one.
         # 3bit is required for 32GB
         qtype: "uint3|qwen_image_edit_torchao_uint3.safetensors"
         quantize_te: true
-        qtype_te: "qfloat8"
         low_vram: true
       sample:
         sampler: "flowmatch" # must match train.noise_scheduler
diff --git a/config/examples/train_lora_qwen_image_edit_plus2509_32gb_2controls.yaml b/config/examples/train_lora_qwen_image_edit_plus2509_32gb_2controls.yaml
new file mode 100644
index 000000000..a01aaafe0
--- /dev/null
+++ b/config/examples/train_lora_qwen_image_edit_plus2509_32gb_2controls.yaml
@@ -0,0 +1,117 @@
+---
+job: extension
+config:
+  # this name will be the folder and file name
+  name: "my_first_qwen_image_edit_plus_20lora_v1"
+  process:
+    - type: 'sd_trainer'
+      # root folder to save training sessions/samples/weights
+      training_folder: "output"
+      # uncomment to see performance stats in the terminal every N steps
+#      performance_log_every: 1000
+      device: cuda:0
+      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
+      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
+      # Trigger words will not work when caching text embeddings
+#      trigger_word: "p3r5on"
+      network:
+        type: "lora"
+        linear: 16
+        linear_alpha: 16
+      save:
+        dtype: float16 # precision to save
+        save_every: 250 # save every this many steps
+        max_step_saves_to_keep: 4 # how many intermittent saves to keep
+      datasets:
+        # datasets are a folder of images. captions need to be txt files with the same name as the image
+        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
+        # images will automatically be resized and bucketed into the resolution specified
+        # on windows, escape back slashes with another backslash so
+        # "C:\\path\\to\\images\\folder"
+        - folder_path: "/path/to/images/folder"
+          control_path: "/path/to/control/images/folder"
+          control_path2: "/path/to/control/images/folder2"
+          # control_path3: "/path/to/control/images/folder3"
+          caption_ext: "txt"
+          # default_caption: "a person" # if caching text embeddings and you don't have captions, this will get cached
+          caption_dropout_rate: 0.05 # will drop out the caption 5% of time
+          resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions
+      train:
+        batch_size: 1
+        # caching text embeddings is required for 32GB
+        cache_text_embeddings: true
+
+        steps: 3000 # total number of steps to train 500 - 4000 is a good range
+        gradient_accumulation: 1
+        timestep_type: "weighted"
+        train_unet: true
+        train_text_encoder: false # probably won't work with qwen image
+        gradient_checkpointing: true # need this on unless you have a ton of vram
+        noise_scheduler: "flowmatch" # for training only
+        optimizer: "adamw8bit"
+        lr: 1e-4
+        # uncomment this to skip the pre training sample
+#        skip_first_sample: true
+        # uncomment to completely disable sampling
+#        disable_sampling: true
+        dtype: bf16
+      model:
+        # huggingface model name or path
+        name_or_path: "Qwen/Qwen-Image-Edit-2509"
+        arch: "qwen_image_edit_plus2509"
+        quantize: true
+        # qtype_te: "qfloat8" # Default float8 quantization
+        # to use the ARA use the | pipe to point to hf path, or a local path if you have one.
+        # 3bit is required for 32GB
+        #! TODO: Quantize and add this to hf
+        # qtype: "uint3|qwen_image_edit_torchao_uint3.safetensors"
+        qtype: "qfloat8"
+        quantize_te: true
+        low_vram: true
+      sample:
+        sampler: "flowmatch" # must match train.noise_scheduler
+        sample_every: 250 # sample every this many steps
+        width: 1024
+        height: 1024
+        samples:
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+            ctrl_img2: "/path/to/control2/image.jpg"
+            # ctrl_img3: "/path/to/control3/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+            ctrl_img2: "/path/to/control2/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+            ctrl_img2: "/path/to/control2/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+            ctrl_img2: "/path/to/control2/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+            ctrl_img2: "/path/to/control2/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+            ctrl_img2: "/path/to/control2/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+            ctrl_img2: "/path/to/control2/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+            ctrl_img2: "/path/to/control2/image.jpg"
+          - prompt: "do the thing to it"
+            ctrl_img: "/path/to/control/image.jpg"
+            ctrl_img2: "/path/to/control2/image.jpg"
+        neg: ""
+        seed: 42
+        walk_seed: true
+        guidance_scale: 3
+        sample_steps: 25
+# you can add any additional meta info here. [name] is replaced with config name at top
+meta:
+  name: "[name]"
+  version: '1.0'
diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py
index 6449fe611..1f24debd5 100644
--- a/extensions_built_in/diffusion_models/__init__.py
+++ b/extensions_built_in/diffusion_models/__init__.py
@@ -4,7 +4,7 @@
 from .omnigen2 import OmniGen2Model
 from .flux_kontext import FluxKontextModel
 from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel
-from .qwen_image import QwenImageModel, QwenImageEditModel
+from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlus2509Model

 AI_TOOLKIT_MODELS = [
     # put a list of models here
@@ -20,4 +20,5 @@
     Wan2214bModel,
     QwenImageModel,
     QwenImageEditModel,
+    QwenImageEditPlus2509Model,
 ]
diff --git a/extensions_built_in/diffusion_models/qwen_image/__init__.py b/extensions_built_in/diffusion_models/qwen_image/__init__.py
index d8b32a851..8e697dc18 100644
--- a/extensions_built_in/diffusion_models/qwen_image/__init__.py
+++ b/extensions_built_in/diffusion_models/qwen_image/__init__.py
@@ -1,2 +1,3 @@
 from .qwen_image import QwenImageModel
-from .qwen_image_edit import QwenImageEditModel
\ No newline at end of file
+from .qwen_image_edit import QwenImageEditModel
+from .qwen_image_edit_plus2509 import QwenImageEditPlus2509Model
diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus2509.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus2509.py
new file mode 100644
index 000000000..82905f541
--- /dev/null
+++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus2509.py
@@ -0,0 +1,362 @@
+import math
+import torch
+from .qwen_image import QwenImageModel
+from typing import TYPE_CHECKING
+from toolkit.config_modules import GenerateImageConfig, ModelConfig
+from PIL import Image
+from toolkit.basic import flush
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.accelerator import unwrap_model
+import torch.nn.functional as F
+
+
+if TYPE_CHECKING:
+    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+
+try:
+    from diffusers import QwenImageEditPlusPipeline
+except ImportError:
+    raise ImportError(
+        "QwenImageEditPlusPipeline not found. 
Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt" + ) + + +class QwenImageEditPlus2509Model(QwenImageModel): + arch = "qwen_image_edit_plus2509" + _qwen_image_keep_visual = True + _qwen_pipeline = QwenImageEditPlusPipeline + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["QwenImageTransformer2DModel"] + + # set true for models that encode control image into text embeddings + self.encode_control_in_text_embeddings = True + + def load_model(self): + super().load_model() + + def get_generation_pipeline(self): + scheduler = QwenImageModel.get_train_scheduler() + + pipeline: QwenImageEditPlusPipeline = QwenImageEditPlusPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + processor=self.processor, + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: QwenImageEditPlusPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + self.model.to(self.device_torch, dtype=self.torch_dtype) + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + # Support multiple control images + control_images = [] + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img) + control_img = control_img.convert("RGB") + # resize to width and height + if control_img.size != (gen_config.width, gen_config.height): + control_img = control_img.resize( + (gen_config.width, gen_config.height), Image.BILINEAR + ) + control_images.append(control_img) + + # Support second control image + if hasattr(gen_config, "ctrl_img2") and gen_config.ctrl_img2 is not None: + control_img2 = Image.open(gen_config.ctrl_img2) + control_img2 = control_img2.convert("RGB") + # resize to width and height + if control_img2.size != (gen_config.width, gen_config.height): + control_img2 = control_img2.resize( + (gen_config.width, gen_config.height), Image.BILINEAR + ) + control_images.append(control_img2) + + # Support third control image if needed + if hasattr(gen_config, "ctrl_img3") and gen_config.ctrl_img3 is not None: + control_img3 = Image.open(gen_config.ctrl_img3) + control_img3 = control_img3.convert("RGB") + # resize to width and height + if control_img3.size != (gen_config.width, gen_config.height): + control_img3 = control_img3.resize( + (gen_config.width, gen_config.height), Image.BILINEAR + ) + control_images.append(control_img3) + + # Use the list of images or None if no control images + control_input = control_images if control_images else None + + # flush for low vram if we are doing that + flush_between_steps = self.model_config.low_vram + + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + if flush_between_steps: + flush() + latents = callback_kwargs["latents"] + + return {"latents": latents} + + img = pipeline( + image=control_input, + prompt_embeds=conditional_embeds.text_embeds, + 
prompt_embeds_mask=conditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + true_cfg_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + callback_on_step_end=callback_on_step_end, + **extra, + ).images[0] + return img + + def condition_noisy_latents( + self, latents: torch.Tensor, batch: "DataLoaderBatchDTO" + ): + with torch.no_grad(): + # Support multiple control tensors + control_tensors = [] + + # Primary control tensor + if batch.control_tensor is not None: + control_tensors.append(batch.control_tensor) + + # Secondary control tensor + if batch.control_tensor2 is not None: + control_tensors.append(batch.control_tensor2) + + # Third control tensor + if batch.control_tensor3 is not None: + control_tensors.append(batch.control_tensor3) + + if not control_tensors: + return latents + + self.vae.to(self.device_torch) + + # Process all control tensors + control_latents = [] + for control_tensor in control_tensors: + # we are not packed here, so we just need to pass them so we can pack them later + control_tensor = control_tensor * 2 - 1 + control_tensor = control_tensor.to( + self.vae_device_torch, dtype=self.torch_dtype + ) + + # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it + if batch.tensor is not None: + target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3] + else: + # When caching latents, batch.tensor is None. We get the size from the file_items instead. + target_h = batch.file_items[0].crop_height + target_w = batch.file_items[0].crop_width + + # Ensure control image matches the main image dimensions so VAE latents align + if ( + control_tensor.shape[2] != target_h + or control_tensor.shape[3] != target_w + ): + control_tensor = F.interpolate( + control_tensor, size=(target_h, target_w), mode="bilinear" + ) + + control_latent = self.encode_images(control_tensor).to( + latents.device, latents.dtype + ) + control_latents.append(control_latent) + + # Concatenate all control latents along the channel dimension + if control_latents: + all_control_latents = torch.cat(control_latents, dim=1) + latents = torch.cat((latents, all_control_latents), dim=1) + + return latents.detach() + + def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds: + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + # Handle multiple control images + processed_control_images = None + if control_images is not None: + # control images are 0 - 1 scale, shape (bs, ch, height, width) + # For multiple images, we need to process each one + if isinstance(control_images, list): + processed_control_images = [] + for img in control_images: + # images are always run through at ~384x384 for conditioning (Edit Plus uses smaller res for VL) + target_area = 384 * 384 + ratio = img.shape[2] / img.shape[3] + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + processed_img = F.interpolate( + img, size=(int(height), int(width)), mode="bilinear" + ) + processed_control_images.append(processed_img) + else: + # Single image case (backward compatibility) + target_area = ( + 384 * 384 + ) # 
Edit Plus uses smaller conditioning resolution + ratio = control_images.shape[2] / control_images.shape[3] + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + processed_control_images = F.interpolate( + control_images, size=(int(height), int(width)), mode="bilinear" + ) + + prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt( + prompt, + image=processed_control_images, + device=self.device_torch, + num_images_per_prompt=1, + ) + pe = PromptEmbeds(prompt_embeds) + pe.attention_mask = prompt_embeds_mask + return pe + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs, + ): + # For Edit Plus, we need to handle multiple control images in the latent concatenation + # The latents are packed as [main_latents, control1_latents, control2_latents, ...] + + # Split the latents - first part is main latents, rest are control latents + main_channels = self.transformer.config.in_channels // 4 + main_latents = latent_model_input[:, :main_channels, :, :] + + # The rest are control latents - we need to split them by number of control images + if latent_model_input.shape[1] > main_channels: + control_latents = latent_model_input[:, main_channels:, :, :] + # Assuming each control image contributes 16 channels + num_control_images = control_latents.shape[1] // main_channels + control_list = torch.chunk(control_latents, num_control_images, dim=1) + else: + control_list = [] + + batch_size, num_channels_latents, height, width = main_latents.shape + + # pack main image tokens + main_latents_packed = main_latents.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + main_latents_packed = main_latents_packed.permute(0, 2, 4, 1, 3, 5) + main_latents_packed = main_latents_packed.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + # pack control image tokens + control_packed_list = [] + img_shapes_per_batch = [(1, height // 2, width // 2)] # Main image shape + + for control in control_list: + control_height, control_width = control.shape[2], control.shape[3] + control_packed = control.view( + batch_size, + num_channels_latents, + control_height // 2, + 2, + control_width // 2, + 2, + ) + control_packed = control_packed.permute(0, 2, 4, 1, 3, 5) + control_packed = control_packed.reshape( + batch_size, + (control_height // 2) * (control_width // 2), + num_channels_latents * 4, + ) + control_packed_list.append(control_packed) + img_shapes_per_batch.append((1, control_height // 2, control_width // 2)) + + # Concatenate all latents (main + controls) along sequence dimension + if control_packed_list: + latent_model_input = torch.cat( + [main_latents_packed] + control_packed_list, dim=1 + ) + else: + latent_model_input = main_latents_packed + + # img_shapes for transformer - repeated for each batch item + img_shapes = [img_shapes_per_batch] * batch_size + + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() + enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype) + # prompt_embeds_mask = text_embeddings.attention_mask.to( + # self.device_torch, dtype=torch.int64 + # ) + + noise_pred = self.transformer( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep / 1000, + guidance=None, + 
encoder_hidden_states=enc_hs, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + return_dict=False, + **kwargs, + )[0] + + # Extract only the main latent predictions (first part of the sequence) + main_seq_len = (height // 2) * (width // 2) + noise_pred = noise_pred[:, :main_seq_len] + + # unpack main latents only + noise_pred = noise_pred.view( + batch_size, height // 2, width // 2, num_channels_latents, 2, 2 + ) + noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5) + noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width) + return noise_pred diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 16c392d7f..5ce14515a 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -129,24 +129,48 @@ def cache_sample_prompts(self): prompt=prompt, # it will autoparse the prompt negative_prompt=sample_item.neg, output_path=output_path, - ctrl_img=sample_item.ctrl_img + ctrl_img=sample_item.ctrl_img, + ctrl_img2=getattr(sample_item, "ctrl_img2", None), + ctrl_img3=getattr(sample_item, "ctrl_img3", None), ) # see if we need to encode the control images - if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None: - ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.sd.device_torch, dtype=self.sd.torch_dtype) - ) + if self.sd.encode_control_in_text_embeddings and ( + gen_img_config.ctrl_img is not None or + gen_img_config.ctrl_img2 is not None or + gen_img_config.ctrl_img3 is not None + ): + control_images = [] + if gen_img_config.ctrl_img is not None: + ci1 = Image.open(gen_img_config.ctrl_img).convert("RGB") + ci1 = ( + TF.to_tensor(ci1) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + control_images.append(ci1) + if getattr(gen_img_config, 'ctrl_img2', None) is not None: + ci2 = Image.open(gen_img_config.ctrl_img2).convert("RGB") + ci2 = ( + TF.to_tensor(ci2) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + control_images.append(ci2) + if getattr(gen_img_config, 'ctrl_img3', None) is not None: + ci3 = Image.open(gen_img_config.ctrl_img3).convert("RGB") + ci3 = ( + TF.to_tensor(ci3) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + control_images.append(ci3) positive = self.sd.encode_prompt( gen_img_config.prompt, - control_images=ctrl_img + control_images=control_images if len(control_images) > 1 else (control_images[0] if len(control_images) == 1 else None) ).to('cpu') negative = self.sd.encode_prompt( gen_img_config.negative_prompt, - control_images=ctrl_img + control_images=control_images if len(control_images) > 1 else (control_images[0] if len(control_images) == 1 else None) ).to('cpu') else: positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu') @@ -1007,14 +1031,24 @@ def get_prior_prediction( embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype) else: prompt_kwargs = {} - if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None: - prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.encode_control_in_text_embeddings and ( + batch.control_tensor is not None or batch.control_tensor2 is not None or batch.control_tensor3 is not None + ): + control_images_list = [] + if 
batch.control_tensor is not None: + control_images_list.append(batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)) + if batch.control_tensor2 is not None: + control_images_list.append(batch.control_tensor2.to(self.sd.device_torch, dtype=self.sd.torch_dtype)) + if batch.control_tensor3 is not None: + control_images_list.append(batch.control_tensor3.to(self.sd.device_torch, dtype=self.sd.torch_dtype)) + prompt_kwargs['control_images'] = control_images_list if len(control_images_list) > 1 else control_images_list[0] embeds_to_use = self.sd.encode_prompt( prompt_list, - long_prompts=self.do_long_prompts).to( - self.device_torch, - dtype=dtype, + long_prompts=self.do_long_prompts, **prompt_kwargs + ).to( + self.device_torch, + dtype=dtype ).detach() # dont use network on this @@ -1388,8 +1422,17 @@ def get_adapter_multiplier(): with self.timer('encode_prompt'): unconditional_embeds = None prompt_kwargs = {} - if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None: - prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.encode_control_in_text_embeddings and ( + batch.control_tensor is not None or batch.control_tensor2 is not None or batch.control_tensor3 is not None + ): + control_images_list = [] + if batch.control_tensor is not None: + control_images_list.append(batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)) + if batch.control_tensor2 is not None: + control_images_list.append(batch.control_tensor2.to(self.sd.device_torch, dtype=self.sd.torch_dtype)) + if batch.control_tensor3 is not None: + control_images_list.append(batch.control_tensor3.to(self.sd.device_torch, dtype=self.sd.torch_dtype)) + prompt_kwargs['control_images'] = control_images_list if len(control_images_list) > 1 else control_images_list[0] if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: with torch.set_grad_enabled(False): if batch.prompt_embeds is not None: @@ -1982,4 +2025,4 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD self.end_of_training_loop() - return loss_dict + return loss_dict \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 39fa3dd01..340262d1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torchao==0.10.0 safetensors git+https://github.com/jaretburkett/easy_dwpose.git -git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63 +git+https://github.com/huggingface/diffusers@579673501535753e382974e1037498cf61c3c7a1 transformers==4.52.4 lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index ca0b3bf2b..335009cc8 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -1,6 +1,6 @@ import os import time -from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict +from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict import random import torch @@ -55,6 +55,8 @@ def __init__( self.fps: int = kwargs.get('fps', sample_config.fps) self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames) self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None) + self.ctrl_img2: Optional[str] = kwargs.get('ctrl_img2', None) + self.ctrl_img3: Optional[str] = kwargs.get('ctrl_img3', None) self.ctrl_idx: int = kwargs.get('ctrl_idx', 0) self.network_multiplier: float = kwargs.get('network_multiplier', 
sample_config.network_multiplier) # convert to a number if it is a string @@ -478,7 +480,7 @@ def __init__(self, **kwargs): # if it is set explicitly to false, leave it false. if ema_config is not None and ema_config.get('use_ema', False): ema_config['use_ema'] = True - print(f"Using EMA") + print("Using EMA") else: ema_config = {'use_ema': False} @@ -531,7 +533,25 @@ def __init__(self, **kwargs): self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1) -ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21'] +ModelArch = Literal[ + "sd1", + "sd2", + "sd3", + "sdxl", + "pixart", + "pixart_sigma", + "auraflow", + "flux", + "flex1", + "flex2", + "lumina2", + "vega", + "ssd", + "wan21", + "qwen_image", + "qwen_image_edit", + "qwen_image_edit_plus2509", +] class ModelConfig: @@ -547,6 +567,9 @@ def __init__(self, **kwargs): self.is_v3: bool = kwargs.get('is_v3', False) self.is_flux: bool = kwargs.get('is_flux', False) self.is_lumina2: bool = kwargs.get('is_lumina2', False) + self.is_qwen_image: bool = kwargs.get('is_qwen_image', False) + self.is_qwen_image_edit: bool = kwargs.get('is_qwen_image_edit', False) + self.is_qwen_image_edit_plus2509: bool = kwargs.get('is_qwen_image_edit_plus2509', False) if self.is_pixart_sigma: self.is_pixart = True self.use_flux_cfg = kwargs.get('use_flux_cfg', False) @@ -659,6 +682,12 @@ def __init__(self, **kwargs): self.is_flux = True elif self.arch == 'lumina2': self.is_lumina2 = True + elif self.arch == 'qwen_image': + self.is_qwen_image = True + elif self.arch == 'qwen_image_edit': + self.is_qwen_image_edit = True + elif self.arch == 'qwen_image_edit_plus2509': + self.is_qwen_image_edit_plus2509 = True elif self.arch == 'vega': self.is_vega = True elif self.arch == 'ssd': @@ -766,7 +795,7 @@ def __init__(self, **kwargs): self.targets: List[SliderTargetConfig] = [] targets = [SliderTargetConfig(**target) for target in targets] # do permutations if shuffle is true - print(f"Building slider targets") + print("Building slider targets") for target in targets: if target.shuffle: target_permutations = get_slider_target_permutations(target, max_permutations=8) @@ -809,6 +838,7 @@ def __init__(self, **kwargs): self.random_scale: bool = kwargs.get('random_scale', False) self.random_crop: bool = kwargs.get('random_crop', False) self.resolution: int = kwargs.get('resolution', 512) + self.keep_native_size: bool = kwargs.get('keep_native_size', False) self.scale: float = kwargs.get('scale', 1.0) self.buckets: bool = kwargs.get('buckets', True) self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64) @@ -825,6 +855,12 @@ def __init__(self, **kwargs): self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc if self.control_path == '': self.control_path = None + self.control_path2: Union[str,List[str]] = kwargs.get('control_path2', None) + if self.control_path2 == '': + self.control_path2 = None + self.control_path3: Union[str,List[str]] = kwargs.get('control_path3', None) + if self.control_path3 == '': + self.control_path3 = None # color for transparent reigon of control images with transparency self.control_transparent_color: List[int] = kwargs.get('control_transparent_color', [0, 0, 0]) @@ -861,7 +897,7 @@ def __init__(self, **kwargs): has_augmentations = self.augmentations is not None and len(self.augmentations) > 0 if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk): - 
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False") + print("WARNING: Augments are not supported with caching latents. Setting cache_latents to False") self.cache_latents = False self.cache_latents_to_disk = False @@ -916,6 +952,14 @@ def __init__(self, **kwargs): self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable + if self.keep_native_size: + if self.square_crop: + raise ValueError("keep_native_size and square_crop cannot be used together") + if self.random_crop: + raise ValueError("keep_native_size and random_crop cannot be used together") + if self.random_scale: + raise ValueError("keep_native_size and random_scale cannot be used together") + def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: """ @@ -966,6 +1010,8 @@ def __init__( extra_values: List[float] = None, # extra values to save with prompt file logger: Optional[EmptyLogger] = None, ctrl_img: Optional[str] = None, # control image for controlnet + ctrl_img2: Optional[str] = None, # second control image for controlnet + ctrl_img3: Optional[str] = None, # third control image for controlnet num_frames: int = 1, fps: int = 15, ctrl_idx: int = 0 @@ -1000,6 +1046,8 @@ def __init__( self.num_frames = num_frames self.fps = fps self.ctrl_img = ctrl_img + self.ctrl_img2 = ctrl_img2 + self.ctrl_img3 = ctrl_img3 self.ctrl_idx = ctrl_idx @@ -1209,6 +1257,10 @@ def _process_prompt_string(self): self.fps = int(content) elif flag == 'ctrl_img': self.ctrl_img = content + elif flag == 'ctrl_img2': + self.ctrl_img2 = content + elif flag == 'ctrl_img3': + self.ctrl_img3 = content elif flag == 'ctrl_idx': self.ctrl_idx = int(content) @@ -1258,8 +1310,8 @@ def validate_configs( raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.") # qwen image edit cannot cache text embeddings - if model_config.arch == 'qwen_image_edit': + if model_config.arch == 'qwen_image_edit' or model_config.arch == 'qwen_image_edit_plus2509': if train_config.unload_text_encoder: raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. 
You can cache the text embeddings though") - + \ No newline at end of file diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index a7ed3759c..27d527f4c 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -144,6 +144,8 @@ def __init__(self, **kwargs): self.tensor: Union[torch.Tensor, None] = None self.latents: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None + self.control_tensor2: Union[torch.Tensor, None] = None + self.control_tensor3: Union[torch.Tensor, None] = None self.clip_image_tensor: Union[torch.Tensor, None] = None self.mask_tensor: Union[torch.Tensor, None] = None self.unaugmented_tensor: Union[torch.Tensor, None] = None @@ -161,6 +163,8 @@ def __init__(self, **kwargs): if is_latents_cached: self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) self.control_tensor: Union[torch.Tensor, None] = None + self.control_tensor2: Union[torch.Tensor, None] = None + self.control_tensor3: Union[torch.Tensor, None] = None self.prompt_embeds: Union[PromptEmbeds, None] = None # if self.file_items[0].control_tensor is not None: # if any have a control tensor, we concatenate them @@ -177,7 +181,45 @@ def __init__(self, **kwargs): control_tensors.append(torch.zeros_like(base_control_tensor)) else: control_tensors.append(x.control_tensor) - self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors]) + self.control_tensor = torch.cat( + [x.unsqueeze(0) for x in control_tensors] + ) + + # Handle control_tensor2 + if any([x.control_tensor2 is not None for x in self.file_items]): + # find one to use as a base + base_control_tensor2 = None + for x in self.file_items: + if x.control_tensor2 is not None: + base_control_tensor2 = x.control_tensor2 + break + control_tensors2 = [] + for x in self.file_items: + if x.control_tensor2 is None: + control_tensors2.append(torch.zeros_like(base_control_tensor2)) + else: + control_tensors2.append(x.control_tensor2) + self.control_tensor2 = torch.cat( + [x.unsqueeze(0) for x in control_tensors2] + ) + + # Handle control_tensor3 + if any([x.control_tensor3 is not None for x in self.file_items]): + # find one to use as a base + base_control_tensor3 = None + for x in self.file_items: + if x.control_tensor3 is not None: + base_control_tensor3 = x.control_tensor3 + break + control_tensors3 = [] + for x in self.file_items: + if x.control_tensor3 is None: + control_tensors3.append(torch.zeros_like(base_control_tensor3)) + else: + control_tensors3.append(x.control_tensor3) + self.control_tensor3 = torch.cat( + [x.unsqueeze(0) for x in control_tensors3] + ) self.inpaint_tensor: Union[torch.Tensor, None] = None if any([x.inpaint_tensor is not None for x in self.file_items]): @@ -329,4 +371,4 @@ def dataset_config(self) -> 'DatasetConfig': if len(self.file_items) > 0: return self.file_items[0].dataset_config else: - return None + return None \ No newline at end of file diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 29b2aced6..3cf248e8b 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -221,6 +221,45 @@ def setup_buckets(self: 'AiToolkitDataset', quiet=False): if file_item.has_point_of_interest: # Attempt to process the poi if we can. It wont process if the image is smaller than the resolution did_process_poi = file_item.setup_poi_bucket() + + if config.keep_native_size and self.batch_size == 1: + #! 
TODO: Currently always enforce divisibility by 32 for Qwen: 16 VAE * 2 patch + if True: + # rescale using config.scale + width = int(width * config.scale) + height = int(height * config.scale) + # bucket_tolerance is 32 for Qwen (it's for the VAE latents) + sc = bucket_tolerance # already set from sd.get_bucket_divisibility() + new_w = width - (width % sc) + new_h = height - (height % sc) + # guard against degenerate case + if new_w <= 0 or new_h <= 0: + new_w = max(sc, (width // sc) * sc) + new_h = max(sc, (height // sc) * sc) + # center-crop minimally to divisibility + dx = max(0, (width - new_w) // 2) + dy = max(0, (height - new_h) // 2) + + file_item.scale_to_width = width + file_item.scale_to_height = height + file_item.crop_width = new_w + file_item.crop_height = new_h + file_item.crop_x = dx + file_item.crop_y = dy + else: + file_item.scale_to_width = width + file_item.scale_to_height = height + file_item.crop_width = width + file_item.crop_height = height + file_item.crop_x = 0 + file_item.crop_y = 0 + + bucket_key = f"{width}x{height}" + if bucket_key not in self.buckets: + self.buckets[bucket_key] = Bucket(width, height) + self.buckets[bucket_key].file_list_idx.append(idx) + continue + if self.dataset_config.square_crop: # we scale first so smallest size matches resolution scale_factor_x = resolution / width @@ -847,9 +886,19 @@ class ControlFileItemDTOMixin: def __init__(self: 'FileItemDTO', *args, **kwargs): if hasattr(super(), '__init__'): super().__init__(*args, **kwargs) + self.has_control_image = False self.control_path: Union[str, List[str], None] = None self.control_tensor: Union[torch.Tensor, None] = None + + self.has_control_image2 = False + self.control_path2: Union[str, List[str], None] = None + self.control_tensor2: Union[torch.Tensor, None] = None + + self.has_control_image3 = False + self.control_path3: Union[str, List[str], None] = None + self.control_tensor3: Union[torch.Tensor, None] = None + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) self.full_size_control_images = False if dataset_config.control_path is not None: @@ -876,76 +925,129 @@ def __init__(self: 'FileItemDTO', *args, **kwargs): # only do one self.control_path = self.control_path[0] + if dataset_config.control_path2 is not None: + control_path_list = dataset_config.control_path2 + if not isinstance(control_path_list, list): + control_path_list = [control_path_list] + self.full_size_control_images = dataset_config.full_size_control_images + # we are using control images + img_path = kwargs.get('path', None) + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + + found_control_images = [] + for control_path in control_path_list: + for ext in img_ext_list: + if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)): + found_control_images.append(os.path.join(control_path, file_name_no_ext + ext)) + self.has_control_image2 = True + break + self.control_path2 = found_control_images + if len(self.control_path2) == 0: + self.control_path2 = None + elif len(self.control_path2) == 1: + # only do one + self.control_path2 = self.control_path2[0] + + if dataset_config.control_path3 is not None: + control_path_list = dataset_config.control_path3 + if not isinstance(control_path_list, list): + control_path_list = [control_path_list] + self.full_size_control_images = dataset_config.full_size_control_images + # we are using control images + img_path = kwargs.get('path', None) + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + + 
found_control_images = [] + for control_path in control_path_list: + for ext in img_ext_list: + if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)): + found_control_images.append(os.path.join(control_path, file_name_no_ext + ext)) + self.has_control_image3 = True + def load_control_image(self: 'FileItemDTO'): + # Load primary control image control_tensors = [] - control_path_list = self.control_path - if not isinstance(self.control_path, list): - control_path_list = [self.control_path] - - for control_path in control_path_list: - try: - img = Image.open(control_path) - img = exif_transpose(img) - - if img.mode in ("RGBA", "LA"): - # Create a background with the specified transparent color - transparent_color = tuple(self.dataset_config.control_transparent_color) - background = Image.new("RGB", img.size, transparent_color) - # Paste the image on top using its alpha channel as mask - background.paste(img, mask=img.getchannel("A")) - img = background - else: - # Already no alpha channel - img = img.convert("RGB") - except Exception as e: - print_acc(f"Error: {e}") - print_acc(f"Error loading image: {control_path}") + if self.control_path is not None: + control_path_list = self.control_path + if not isinstance(self.control_path, list): + control_path_list = [self.control_path] - if not self.full_size_control_images: - # we just scale them to 512x512: - w, h = img.size - img = img.resize((512, 512), Image.BICUBIC) + for control_path in control_path_list: + try: + img = Image.open(control_path) + img = exif_transpose(img) + + if img.mode in ("RGBA", "LA"): + # Create a background with the specified transparent color + transparent_color = tuple( + self.dataset_config.control_transparent_color + ) + background = Image.new("RGB", img.size, transparent_color) + # Paste the image on top using its alpha channel as mask + background.paste(img, mask=img.getchannel("A")) + img = background + else: + # Already no alpha channel + img = img.convert("RGB") + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {control_path}") - else: - w, h = img.size - if w > h and self.scale_to_width < self.scale_to_height: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") - elif h > w and self.scale_to_height < self.scale_to_width: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + if not self.full_size_control_images: + # we just scale them to 512x512: + w, h = img.size + img = img.resize((512, 512), Image.BICUBIC) - if self.flip_x: - # do a flip - img = img.transpose(Image.FLIP_LEFT_RIGHT) - if self.flip_y: - # do a flip - img = img.transpose(Image.FLIP_TOP_BOTTOM) + else: + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}" + ) + elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, 
file_item.path={self.path}" + ) - if self.dataset_config.buckets: - # scale and crop based on file item - img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) - # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) - # crop - img = img.crop(( - self.crop_x, - self.crop_y, - self.crop_x + self.crop_width, - self.crop_y + self.crop_height - )) + if self.flip_x: + # do a flip + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if self.dataset_config.buckets: + # scale and crop based on file item + img = img.resize( + (self.scale_to_width, self.scale_to_height), Image.BICUBIC + ) + # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + # crop + img = img.crop( + ( + self.crop_x, + self.crop_y, + self.crop_x + self.crop_width, + self.crop_y + self.crop_height, + ) + ) + else: + raise Exception( + "Control images not supported for non-bucket datasets" + ) + transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + if self.aug_replay_spatial_transforms: + tensor = self.augment_spatial_control(img, transform=transform) else: - raise Exception("Control images not supported for non-bucket datasets") - transform = transforms.Compose([ - transforms.ToTensor(), - ]) - if self.aug_replay_spatial_transforms: - tensor = self.augment_spatial_control(img, transform=transform) - else: - tensor = transform(img) - control_tensors.append(tensor) - + tensor = transform(img) + control_tensors.append(tensor) + if len(control_tensors) == 0: self.control_tensor = None elif len(control_tensors) == 1: @@ -953,8 +1055,160 @@ def load_control_image(self: 'FileItemDTO'): else: self.control_tensor = torch.stack(control_tensors, dim=0) + # Load second control image + if self.has_control_image2 and self.control_path2 is not None: + control_path2_list = self.control_path2 + if not isinstance(self.control_path2, list): + control_path2_list = [self.control_path2] + + control_tensors2 = [] + for control_path in control_path2_list: + try: + img = Image.open(control_path) + img = exif_transpose(img) + + if img.mode in ("RGBA", "LA"): + transparent_color = tuple( + self.dataset_config.control_transparent_color + ) + background = Image.new("RGB", img.size, transparent_color) + background.paste(img, mask=img.getchannel("A")) + img = background + else: + img = img.convert("RGB") + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {control_path}") + + if not self.full_size_control_images: + w, h = img.size + img = img.resize((512, 512), Image.BICUBIC) + else: + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}" + ) + elif h > w and self.scale_to_height < self.scale_to_width: + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}" + ) + + if self.flip_x: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if self.dataset_config.buckets: + img = img.resize( + (self.scale_to_width, self.scale_to_height), Image.BICUBIC + ) + img = img.crop( + ( + self.crop_x, + self.crop_y, + self.crop_x + self.crop_width, + self.crop_y + self.crop_height, + ) + ) + else: + raise 
Exception( + "Control images not supported for non-bucket datasets" + ) + + transform = transforms.Compose([transforms.ToTensor()]) + if self.aug_replay_spatial_transforms: + tensor = self.augment_spatial_control(img, transform=transform) + else: + tensor = transform(img) + control_tensors2.append(tensor) + + if len(control_tensors2) == 0: + self.control_tensor2 = None + elif len(control_tensors2) == 1: + self.control_tensor2 = control_tensors2[0] + else: + self.control_tensor2 = torch.stack(control_tensors2, dim=0) + + # Load third control image + if self.has_control_image3 and self.control_path3 is not None: + control_path3_list = self.control_path3 + if not isinstance(self.control_path3, list): + control_path3_list = [self.control_path3] + + control_tensors3 = [] + for control_path in control_path3_list: + try: + img = Image.open(control_path) + img = exif_transpose(img) + + if img.mode in ("RGBA", "LA"): + transparent_color = tuple( + self.dataset_config.control_transparent_color + ) + background = Image.new("RGB", img.size, transparent_color) + background.paste(img, mask=img.getchannel("A")) + img = background + else: + img = img.convert("RGB") + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading image: {control_path}") + + if not self.full_size_control_images: + w, h = img.size + img = img.resize((512, 512), Image.BICUBIC) + else: + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}" + ) + elif h > w and self.scale_to_height < self.scale_to_width: + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}" + ) + + if self.flip_x: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + img = img.transpose(Image.FLIP_TOP_BOTTOM) + + if self.dataset_config.buckets: + img = img.resize( + (self.scale_to_width, self.scale_to_height), Image.BICUBIC + ) + img = img.crop( + ( + self.crop_x, + self.crop_y, + self.crop_x + self.crop_width, + self.crop_y + self.crop_height, + ) + ) + else: + raise Exception( + "Control images not supported for non-bucket datasets" + ) + + transform = transforms.Compose([transforms.ToTensor()]) + if self.aug_replay_spatial_transforms: + tensor = self.augment_spatial_control(img, transform=transform) + else: + tensor = transform(img) + control_tensors3.append(tensor) + + if len(control_tensors3) == 0: + self.control_tensor3 = None + elif len(control_tensors3) == 1: + self.control_tensor3 = control_tensors3[0] + else: + self.control_tensor3 = torch.stack(control_tensors3, dim=0) + def cleanup_control(self: 'FileItemDTO'): self.control_tensor = None + self.control_tensor2 = None + self.control_tensor3 = None class ClipImageFileItemDTOMixin: @@ -1819,9 +2073,14 @@ def get_text_embedding_info_dict(self: 'FileItemDTO'): ("text_embedding_space_version", self.text_embedding_space_version), ("text_embedding_version", self.text_embedding_version), ]) - # if we have a control image, cache the path - if self.encode_control_in_text_embeddings and self.control_path is not None: - item["control_path"] = self.control_path + # if we have control images, cache the paths + if self.encode_control_in_text_embeddings: + if self.control_path is not None: + item["control_path"] = self.control_path + if hasattr(self, 
"control_path2") and self.control_path2 is not None: + item["control_path2"] = self.control_path2 + if hasattr(self, "control_path3") and self.control_path3 is not None: + item["control_path3"] = self.control_path3 return item def get_text_embedding_path(self: 'FileItemDTO', recalculate=False): @@ -1882,19 +2141,64 @@ def cache_text_embeddings(self: 'AiToolkitDataset'): did_move = True if file_item.encode_control_in_text_embeddings: - if file_item.control_path is None: - raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model") - # load the control image and feed it into the text encoder - ctrl_img = Image.open(file_item.control_path).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + # Collect all control images for multi-image support + control_images = [] + + if file_item.control_path is not None: + ctrl_img = Image.open(file_item.control_path).convert("RGB") + ctrl_img_tensor = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + control_images.append(ctrl_img_tensor) + + if ( + hasattr(file_item, "control_path2") + and file_item.control_path2 is not None + ): + ctrl_img2 = Image.open(file_item.control_path2).convert( + "RGB" + ) + ctrl_img2_tensor = ( + TF.to_tensor(ctrl_img2) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + control_images.append(ctrl_img2_tensor) + + if ( + hasattr(file_item, "control_path3") + and file_item.control_path3 is not None + ): + ctrl_img3 = Image.open(file_item.control_path3).convert( + "RGB" + ) + ctrl_img3_tensor = ( + TF.to_tensor(ctrl_img3) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + control_images.append(ctrl_img3_tensor) + + if not control_images: + raise Exception( + f"Could not find any control images for {file_item.path} which is needed for this model" + ) + + # Use list for multiple images or single image for backward compatibility + ctrl_input = ( + control_images + if len(control_images) > 1 + else control_images[0] + ) + prompt_embeds: PromptEmbeds = self.sd.encode_prompt( + file_item.caption, control_images=ctrl_input ) - prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img) else: - prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption) + prompt_embeds: PromptEmbeds = self.sd.encode_prompt( + file_item.caption + ) # save it prompt_embeds.save(text_embedding_path) del prompt_embeds @@ -2127,4 +2431,4 @@ def setup_controls(self: 'AiToolkitDataset'): self.control_generator.cleanup() self.control_generator = None - flush() + flush() \ No newline at end of file diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index af5ea06fe..d8795389e 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -503,13 +503,37 @@ def generate_images( else: ctrl_img = None # load the control image if out model uses it in text encoding - if gen_config.ctrl_img is not None and self.encode_control_in_text_embeddings: - ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB") - # convert to 0 to 1 tensor + if self.encode_control_in_text_embeddings: + control_images = [] + if gen_config.ctrl_img is not None: + ctr_img = Image.open(gen_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_tensor = ( + TF.to_tensor(ctr_img) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + 
control_images.append(ctrl_img_tensor) + if gen_config.ctrl_img2 is not None: + ctr_img2 = Image.open(gen_config.ctrl_img2).convert("RGB") + ctrl_img2_tensor = ( + TF.to_tensor(ctr_img2) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + control_images.append(ctrl_img2_tensor) + if gen_config.ctrl_img3 is not None: + ctr_img3 = Image.open(gen_config.ctrl_img3).convert("RGB") + ctrl_img3_tensor = ( + TF.to_tensor(ctr_img3) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + control_images.append(ctrl_img3_tensor) ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) + control_images + if len(control_images) > 1 + else (control_images[0] if control_images else None) ) # encode the prompt ourselves so we can do fun stuff with embeddings if isinstance(self.adapter, CustomAdapter): @@ -1535,4 +1559,4 @@ def get_base_model_version(self) -> str: def get_model_to_train(self): # called to get model to attach LoRAs to. Can be overridden in child classes - return self.unet + return self.unet \ No newline at end of file diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 7735ae35a..8d78be40c 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -335,6 +335,28 @@ export const modelArchs: ModelArch[] = [ '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors', }, }, + { + name: 'qwen_image_edit_plus2509', + label: 'Qwen-Image-Edit-Plus-2509', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit-2509', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + }, + disableSections: ['network.conv'], + additionalSections: ['datasets.control_path', 'sample.ctrl_img', 'model.low_vram'], + accuracyRecoveryAdapters: { + //! TODO: Quantize and add this to hf + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_plus2509_torchao_uint3.safetensors', + }, + }, { name: 'hidream', label: 'HiDream',