diff --git a/README.md b/README.md index d987f585c..e3118fdf0 100644 --- a/README.md +++ b/README.md @@ -287,45 +287,12 @@ You will instantiate a UI that will let you upload your images, caption them, tr ## Training in RunPod -Example RunPod template: **runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04** -> You need a minimum of 24GB VRAM, pick a GPU by your preference. +If you would like to use Runpod, but have not signed up yet, please consider using [my Runpod affiliate link](https://runpod.io?ref=h0y9jyr2) to help support this project. -#### Example config ($0.5/hr): -- 1x A40 (48 GB VRAM) -- 19 vCPU 100 GB RAM -#### Custom overrides (you need some storage to clone FLUX.1, store datasets, store trained models and samples): -- ~120 GB Disk -- ~120 GB Pod Volume -- Start Jupyter Notebook +I maintain an official Runpod Pod template here which can be accessed [here](https://console.runpod.io/deploy?template=0fqzfjy6f3&ref=h0y9jyr2). -### 1. Setup -``` -git clone https://github.com/ostris/ai-toolkit.git -cd ai-toolkit -git submodule update --init --recursive -python -m venv venv -source venv/bin/activate -pip install torch -pip install -r requirements.txt -pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues -``` -### 2. Upload your dataset -- Create a new folder in the root, name it `dataset` or whatever you like. -- Drag and drop your .jpg, .jpeg, or .png images and .txt files inside the newly created dataset folder. - -### 3. Login into Hugging Face with an Access Token -- Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev). -- Run ```huggingface-cli login``` and paste your token. - -### 4. Training -- Copy an example config file located at ```config/examples``` to the config folder and rename it to ```whatever_you_want.yml```. -- Edit the config following the comments in the file. -- Change ```folder_path: "/path/to/images/folder"``` to your dataset path like ```folder_path: "/workspace/ai-toolkit/your-dataset"```. -- Run the file: ```python run.py config/whatever_you_want.yml```. - -### Screenshot from RunPod -RunPod Training Screenshot +I have also created a short video showing how to get started using AI Toolkit with Runpod [here](https://youtu.be/HBNeS-F6Zz8). ## Training in Modal diff --git a/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml b/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml new file mode 100644 index 000000000..845ac92b6 --- /dev/null +++ b/config/examples/train_lora_qwen_image_edit_2509_32gb.yaml @@ -0,0 +1,105 @@ +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_qwen_image_edit_2509_lora_v1" + process: + - type: 'diffusion_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + network: + type: "lora" + linear: 16 + linear_alpha: 16 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently + # images will automatically be resized and bucketed into the resolution specified + # on windows, escape back slashes with another backslash so + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/folder" + # can do up to 3 control image folders, file names must match target file names, but aspect/size can be different + control_path: + - "/path/to/control/images/folder1" + - "/path/to/control/images/folder2" + - "/path/to/control/images/folder3" + caption_ext: "txt" + # default_caption: "a person" # if caching text embeddings, if you don't have captions, this will get cached + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + resolution: [ 512, 768, 1024 ] # qwen image enjoys multiple resolutions + # a trigger word that can be cached with the text embeddings + # trigger_word: "optional trigger word" + train: + batch_size: 1 + # caching text embeddings is required for 32GB + cache_text_embeddings: true + # unload_text_encoder: true + + steps: 3000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + timestep_type: "weighted" + train_unet: true + train_text_encoder: false # probably won't work with qwen image + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + optimizer: "adamw8bit" + lr: 1e-4 + # uncomment this to skip the pre training sample + # skip_first_sample: true + # uncomment to completely disable sampling + # disable_sampling: true + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "Qwen/Qwen-Image-Edit-2509" + arch: "qwen_image_edit_plus" + quantize: true + # to use the ARA use the | pipe to point to hf path, or a local path if you have one. + # 3bit is required for 32GB + qtype: "uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors" + quantize_te: true + qtype_te: "qfloat8" + low_vram: true + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + # you can provide up to 3 control images here + samples: + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + - prompt: "Do whatever with Image1 and Image2" + ctrl_img_1: "/path/to/image1.png" + ctrl_img_2: "/path/to/image2.png" + # ctrl_img_3: "/path/to/image3.png" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 3 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/config/examples/train_lora_wan22_14b_24gb.yaml b/config/examples/train_lora_wan22_14b_24gb.yaml new file mode 100644 index 000000000..966f184f4 --- /dev/null +++ b/config/examples/train_lora_wan22_14b_24gb.yaml @@ -0,0 +1,111 @@ +# this example focuses mainly for training Wan2.2 14b on images. 
It will work for video as well by increasing +# the number of frames in the dataset and samples. Training on and generating video is very VRAM intensive. +--- +job: extension +config: + # this name will be the folder and filename name + name: "my_first_wan22_14b_lora_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps +# performance_log_every: 1000 + device: cuda:0 + # Use a trigger word if train.unload_text_encoder is true, however, if caching text embeddings, do not use a trigger word + # trigger_word: "p3r5on" + network: + type: "lora" + linear: 32 + linear_alpha: 32 + save: + dtype: float16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 4 # how many intermittent saves to keep + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. + # "C:\\path\\to\\images\\folder" + - folder_path: "/path/to/images/or/video/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 # will drop out the caption 5% of time + # number of frames to extract from your video. It will automatically extract them evenly spaced + # set to 1 frame for images + num_frames: 1 + resolution: [ 512, 768, 1024] + train: + batch_size: 1 + steps: 2000 # total number of steps to train 500 - 4000 is a good range + gradient_accumulation: 1 + train_unet: true + train_text_encoder: false # probably won't work with wan + gradient_checkpointing: true # need the on unless you have a ton of vram + noise_scheduler: "flowmatch" # for training only + timestep_type: 'linear' + optimizer: "adamw8bit" + lr: 1e-4 + optimizer_params: + weight_decay: 1e-4 + # uncomment this to skip the pre training sample +# skip_first_sample: true + # uncomment to completely disable sampling +# disable_sampling: true + dtype: bf16 + + # IMPORTANT: this is for Wan 2.2 MOE. It will switch training one stage or the other every this many steps + switch_boundary_every: 10 + + # required for 24GB cards. You must do either unload_text_encoder or cache_text_embeddings but not both + + # this will encode your trigger word and use those embeddings for every image in the dataset, captions will be ignored + # unload_text_encoder: true + + # this will cache all captions in your dataset. + cache_text_embeddings: true + + model: + # huggingface model name or path, this one if bf16, vs the float32 of the official repo + name_or_path: "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16" + arch: 'wan22_14b' + quantize: true + # This will pull and use a custom Accuracy Recovery Adapter to train at 4bit + qtype: "uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors" + quantize_te: true + qtype_te: "qfloat8" + low_vram: true + model_kwargs: + # you can train high noise, low noise, or both. With low vram it will automatically unload the one not being trained. + train_high_noise: true + train_low_noise: true + sample: + sampler: "flowmatch" + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + # set to 1 for images + num_frames: 1 + fps: 16 + # samples take a long time. so use them sparingly + # samples will be animated webp files, if you don't see them animated, open in a browser. 
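+      # Video sketch (illustrative values, not a tested recipe): to train and sample on
+      # video instead of stills, raise num_frames in the dataset above and in this sample
+      # block, e.g. num_frames: 33 at fps: 16 for roughly a two second clip, and expect
+      # to shrink the resolution list, since video uses far more VRAM than images.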
+ prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word +# - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop" + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 3.5 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0' diff --git a/extensions_built_in/concept_slider/ConceptSliderTrainer.py b/extensions_built_in/concept_slider/ConceptSliderTrainer.py new file mode 100644 index 000000000..6c9fce74e --- /dev/null +++ b/extensions_built_in/concept_slider/ConceptSliderTrainer.py @@ -0,0 +1,302 @@ +from collections import OrderedDict +from typing import Optional + +import torch + +from extensions_built_in.sd_trainer.DiffusionTrainer import DiffusionTrainer +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds +from toolkit.train_tools import get_torch_dtype + + +class ConceptSliderTrainerConfig: + def __init__(self, **kwargs): + self.guidance_strength: float = kwargs.get("guidance_strength", 3.0) + self.anchor_strength: float = kwargs.get("anchor_strength", 1.0) + self.positive_prompt: str = kwargs.get("positive_prompt", "") + self.negative_prompt: str = kwargs.get("negative_prompt", "") + self.target_class: str = kwargs.get("target_class", "") + self.anchor_class: Optional[str] = kwargs.get("anchor_class", None) + + +def norm_like_tensor(tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Normalize the tensor to have the same mean and std as the target tensor.""" + tensor_mean = tensor.mean() + tensor_std = tensor.std() + target_mean = target.mean() + target_std = target.std() + normalized_tensor = (tensor - tensor_mean) / ( + tensor_std + 1e-8 + ) * target_std + target_mean + return normalized_tensor + + +class ConceptSliderTrainer(DiffusionTrainer): + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + self.do_guided_loss = True + + self.slider: ConceptSliderTrainerConfig = ConceptSliderTrainerConfig( + **self.config.get("slider", {}) + ) + + self.positive_prompt = self.slider.positive_prompt + self.positive_prompt_embeds: Optional[PromptEmbeds] = None + self.negative_prompt = self.slider.negative_prompt + self.negative_prompt_embeds: Optional[PromptEmbeds] = None + self.target_class = self.slider.target_class + self.target_class_embeds: Optional[PromptEmbeds] = None + self.anchor_class = self.slider.anchor_class + self.anchor_class_embeds: Optional[PromptEmbeds] = None + + def hook_before_train_loop(self): + # do this before 
calling parent as it unloads the text encoder if requested + if self.is_caching_text_embeddings: + # make sure model is on cpu for this part so we don't oom. + self.sd.unet.to("cpu") + + # cache unconditional embeds (blank prompt) + with torch.no_grad(): + self.positive_prompt_embeds = ( + self.sd.encode_prompt( + [self.positive_prompt], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + self.target_class_embeds = ( + self.sd.encode_prompt( + [self.target_class], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + self.negative_prompt_embeds = ( + self.sd.encode_prompt( + [self.negative_prompt], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + if self.anchor_class is not None: + self.anchor_class_embeds = ( + self.sd.encode_prompt( + [self.anchor_class], + ) + .to(self.device_torch, dtype=self.sd.torch_dtype) + .detach() + ) + + # call parent + super().hook_before_train_loop() + + def get_guided_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: "DataLoaderBatchDTO", + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs, + ): + # todo for embeddings, we need to run without trigger words + was_unet_training = self.sd.unet.training + was_network_active = False + if self.network is not None: + was_network_active = self.network.is_active + self.network.is_active = False + + # do out prior preds first + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + self.sd.unet.eval() + noisy_latents = noisy_latents.to(self.device_torch, dtype=dtype).detach() + + batch_size = noisy_latents.shape[0] + + positive_embeds = concat_prompt_embeds( + [self.positive_prompt_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + target_class_embeds = concat_prompt_embeds( + [self.target_class_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + negative_embeds = concat_prompt_embeds( + [self.negative_prompt_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + + if self.anchor_class_embeds is not None: + anchor_embeds = concat_prompt_embeds( + [self.anchor_class_embeds] * batch_size + ).to(self.device_torch, dtype=dtype) + + if self.anchor_class_embeds is not None: + # if we have an anchor, do it + combo_embeds = concat_prompt_embeds( + [ + positive_embeds, + target_class_embeds, + negative_embeds, + anchor_embeds, + ] + ) + num_embeds = 4 + else: + combo_embeds = concat_prompt_embeds( + [positive_embeds, target_class_embeds, negative_embeds] + ) + num_embeds = 3 + + # do them in one batch, VRAM should handle it since we are no grad + combo_pred = self.sd.predict_noise( + latents=torch.cat([noisy_latents] * num_embeds, dim=0), + conditional_embeddings=combo_embeds, + timestep=torch.cat([timesteps] * num_embeds, dim=0), + guidance_scale=1.0, + guidance_embedding_scale=1.0, + batch=batch, + ) + + if self.anchor_class_embeds is not None: + positive_pred, neutral_pred, negative_pred, anchor_target = ( + combo_pred.chunk(4, dim=0) + ) + else: + anchor_target = None + positive_pred, neutral_pred, negative_pred = combo_pred.chunk(3, dim=0) + + # calculate the targets + guidance_scale = self.slider.guidance_strength + + # enhance_positive_target = neutral_pred + guidance_scale * ( + # positive_pred - negative_pred + # ) + # enhance_negative_target = neutral_pred + guidance_scale * ( + # negative_pred - positive_pred + # ) + 
# erase_negative_target = neutral_pred - guidance_scale * ( + # negative_pred - positive_pred + # ) + # erase_positive_target = neutral_pred - guidance_scale * ( + # positive_pred - negative_pred + # ) + + positive = (positive_pred - neutral_pred) - (negative_pred - neutral_pred) + negative = (negative_pred - neutral_pred) - (positive_pred - neutral_pred) + + enhance_positive_target = neutral_pred + guidance_scale * positive + enhance_negative_target = neutral_pred + guidance_scale * negative + erase_negative_target = neutral_pred - guidance_scale * negative + erase_positive_target = neutral_pred - guidance_scale * positive + + # normalize to neutral std/mean + enhance_positive_target = norm_like_tensor( + enhance_positive_target, neutral_pred + ) + enhance_negative_target = norm_like_tensor( + enhance_negative_target, neutral_pred + ) + erase_negative_target = norm_like_tensor( + erase_negative_target, neutral_pred + ) + erase_positive_target = norm_like_tensor( + erase_positive_target, neutral_pred + ) + + if was_unet_training: + self.sd.unet.train() + + # restore network + if self.network is not None: + self.network.is_active = was_network_active + + if self.anchor_class_embeds is not None: + # do a grad inference with our target prompt + embeds = concat_prompt_embeds([target_class_embeds, anchor_embeds]).to( + self.device_torch, dtype=dtype + ) + + noisy_latents = torch.cat([noisy_latents, noisy_latents], dim=0).to( + self.device_torch, dtype=dtype + ) + timesteps = torch.cat([timesteps, timesteps], dim=0) + else: + embeds = target_class_embeds.to(self.device_torch, dtype=dtype) + + # do positive first + self.network.set_multiplier(1.0) + pred = self.sd.predict_noise( + latents=noisy_latents, + conditional_embeddings=embeds, + timestep=timesteps, + guidance_scale=1.0, + guidance_embedding_scale=1.0, + batch=batch, + ) + + if self.anchor_class_embeds is not None: + class_pred, anchor_pred = pred.chunk(2, dim=0) + else: + class_pred = pred + anchor_pred = None + + # enhance positive loss + enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_positive_target) + + erase_loss = torch.nn.functional.mse_loss(class_pred, erase_negative_target) + + if anchor_target is None: + anchor_loss = torch.zeros_like(erase_loss) + else: + anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target) + + anchor_loss = anchor_loss * self.slider.anchor_strength + + # send backward now because gradient checkpointing needs network polarity intact + total_pos_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0 + total_pos_loss.backward() + total_pos_loss = total_pos_loss.detach() + + # now do negative + self.network.set_multiplier(-1.0) + pred = self.sd.predict_noise( + latents=noisy_latents, + conditional_embeddings=embeds, + timestep=timesteps, + guidance_scale=1.0, + guidance_embedding_scale=1.0, + batch=batch, + ) + + if self.anchor_class_embeds is not None: + class_pred, anchor_pred = pred.chunk(2, dim=0) + else: + class_pred = pred + anchor_pred = None + + # enhance negative loss + enhance_loss = torch.nn.functional.mse_loss(class_pred, enhance_negative_target) + erase_loss = torch.nn.functional.mse_loss(class_pred, erase_positive_target) + + if anchor_target is None: + anchor_loss = torch.zeros_like(erase_loss) + else: + anchor_loss = torch.nn.functional.mse_loss(anchor_pred, anchor_target) + anchor_loss = anchor_loss * self.slider.anchor_strength + total_neg_loss = (enhance_loss + erase_loss + anchor_loss) / 3.0 + total_neg_loss.backward() + total_neg_loss = 
total_neg_loss.detach() + + self.network.set_multiplier(1.0) + + total_loss = (total_pos_loss + total_neg_loss) / 2.0 + + # add a grad so backward works right + total_loss.requires_grad_(True) + return total_loss diff --git a/extensions_built_in/concept_slider/__init__.py b/extensions_built_in/concept_slider/__init__.py new file mode 100644 index 000000000..7a624b185 --- /dev/null +++ b/extensions_built_in/concept_slider/__init__.py @@ -0,0 +1,26 @@ +# This is an example extension for custom training. It is great for experimenting with new ideas. +from toolkit.extension import Extension + + +# This is for generic training (LoRA, Dreambooth, FineTuning) +class ConceptSliderTrainerTrainer(Extension): + # uid must be unique, it is how the extension is identified + uid = "concept_slider" + + # name is the name of the extension for printing + name = "Concept Slider Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .ConceptSliderTrainer import ConceptSliderTrainer + + return ConceptSliderTrainer + + +AI_TOOLKIT_EXTENSIONS = [ + # you can put a list of extensions here + ConceptSliderTrainerTrainer +] diff --git a/extensions_built_in/diffusion_models/__init__.py b/extensions_built_in/diffusion_models/__init__.py index 5682bc462..f7d874e54 100644 --- a/extensions_built_in/diffusion_models/__init__.py +++ b/extensions_built_in/diffusion_models/__init__.py @@ -1,14 +1,15 @@ -from .chroma import ChromaModel +from .chroma import ChromaModel, ChromaRadianceModel from .hidream import HidreamModel, HidreamE1Model from .f_light import FLiteModel from .omnigen2 import OmniGen2Model from .flux_kontext import FluxKontextModel from .wan22 import Wan225bModel, Wan2214bModel, Wan2214bI2VModel -from .qwen_image import QwenImageModel, QwenImageEditModel +from .qwen_image import QwenImageModel, QwenImageEditModel, QwenImageEditPlusModel AI_TOOLKIT_MODELS = [ # put a list of models here ChromaModel, + ChromaRadianceModel, HidreamModel, HidreamE1Model, FLiteModel, @@ -19,4 +20,5 @@ Wan2214bModel, QwenImageModel, QwenImageEditModel, + QwenImageEditPlusModel, ] diff --git a/extensions_built_in/diffusion_models/chroma/__init__.py b/extensions_built_in/diffusion_models/chroma/__init__.py index b20e2f40c..c34866a6a 100644 --- a/extensions_built_in/diffusion_models/chroma/__init__.py +++ b/extensions_built_in/diffusion_models/chroma/__init__.py @@ -1 +1,2 @@ -from .chroma_model import ChromaModel \ No newline at end of file +from .chroma_model import ChromaModel +from .chroma_radiance_model import ChromaRadianceModel \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py index 9bf3e51e2..236d9508b 100644 --- a/extensions_built_in/diffusion_models/chroma/chroma_model.py +++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py @@ -15,7 +15,7 @@ from optimum.quanto import freeze, QTensor from toolkit.util.quantize import quantize, get_qtype from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer -from .pipeline import ChromaPipeline +from .pipeline import ChromaPipeline, prepare_latent_image_ids from einops import rearrange, repeat import random import torch.nn.functional as F @@ -324,12 +324,19 @@ def get_noise_prediction( ph=2, pw=2 ) - - img_ids = torch.zeros(h // 2, 
w // 2, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", - b=bs).to(self.device_torch) + + img_ids = prepare_latent_image_ids( + bs, + h, + w, + patch_size=2 + ).to(device=self.device_torch) + + # img_ids = torch.zeros(h // 2, w // 2, 3) + # img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + # img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + # img_ids = repeat(img_ids, "h w c -> b (h w) c", + # b=bs).to(self.device_torch) txt_ids = torch.zeros( bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) diff --git a/extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py b/extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py new file mode 100644 index 000000000..333600e8f --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/chroma_radiance_model.py @@ -0,0 +1,445 @@ +import os +from typing import TYPE_CHECKING + +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from diffusers import AutoencoderKL +# from toolkit.pixel_shuffle_encoder import AutoencoderPixelMixer +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.accelerator import unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype +from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer +from .pipeline import ChromaPipeline, prepare_latent_image_ids +from einops import rearrange, repeat +import random +import torch.nn.functional as F +from .src.radiance import Chroma, chroma_params +from safetensors.torch import load_file, save_file +from toolkit.metadata import get_meta_for_safetensors +from toolkit.models.FakeVAE import FakeVAE +import huggingface_hub + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True +} + +class FakeConfig: + # for diffusers compatability + def __init__(self): + self.attention_head_dim = 128 + self.guidance_embeds = True + self.in_channels = 64 + self.joint_attention_dim = 4096 + self.num_attention_heads = 24 + self.num_layers = 19 + self.num_single_layers = 38 + self.patch_size = 1 + +class FakeCLIP(torch.nn.Module): + def __init__(self): + super().__init__() + self.dtype = torch.bfloat16 + self.device = 'cuda' + self.text_model = None + self.tokenizer = None + self.model_max_length = 77 + + def forward(self, *args, **kwargs): + return torch.zeros(1, 1, 1).to(self.device) + + +class ChromaRadianceModel(BaseModel): + arch = "chroma_radiance" + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__( + device, + model_config, + dtype, + custom_pipeline, + noise_scheduler, + **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['Chroma'] + + # static method to get the noise scheduler + @staticmethod + def 
get_train_scheduler(): + return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + + def get_bucket_divisibility(self): + # return the bucket divisibility for the model + return 32 + + def load_model(self): + dtype = self.torch_dtype + + # will be updated if we detect a existing checkpoint in training folder + model_path = self.model_config.name_or_path + + if model_path == "lodestones/Chroma": + print("Looking for latest Chroma checkpoint") + # get the latest checkpoint + files_list = huggingface_hub.list_repo_files(model_path) + print(files_list) + latest_version = 28 # current latest version at time of writing + while True: + if f"chroma-unlocked-v{latest_version}.safetensors" not in files_list: + latest_version -= 1 + break + else: + latest_version += 1 + print(f"Using latest Chroma version: v{latest_version}") + + # make sure we have it + model_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=f"chroma-unlocked-v{latest_version}.safetensors", + ) + elif model_path.startswith("lodestones/Chroma/v"): + # get the version number + version = model_path.split("/")[-1].split("v")[-1] + print(f"Using Chroma version: v{version}") + # make sure we have it + model_path = huggingface_hub.hf_hub_download( + repo_id='lodestones/Chroma', + filename=f"chroma-unlocked-v{version}.safetensors", + ) + elif model_path.startswith("lodestones/Chroma1-"): + # will have a file in the repo that is Chroma1-whatever.safetensors + model_path = huggingface_hub.hf_hub_download( + repo_id=model_path, + filename=f"{model_path.split('/')[-1]}.safetensors", + ) + + else: + # check if the model path is a local file + if os.path.exists(model_path): + print(f"Using local model: {model_path}") + else: + raise ValueError(f"Model path {model_path} does not exist") + + # extras_path = 'black-forest-labs/FLUX.1-schnell' + # schnell model is gated now, use flex instead + extras_path = 'ostris/Flex.1-alpha' + + self.print_and_status_update("Loading transformer") + + if model_path.endswith('.pth') or model_path.endswith('.pt'): + chroma_state_dict = torch.load(model_path, map_location='cpu', weights_only=True) + else: + chroma_state_dict = load_file(model_path, 'cpu') + + # determine number of double and single blocks + double_blocks = 0 + single_blocks = 0 + for key in chroma_state_dict.keys(): + if "double_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > double_blocks: + double_blocks = block_num + elif "single_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > single_blocks: + single_blocks = block_num + print(f"Double Blocks: {double_blocks}") + print(f"Single Blocks: {single_blocks}") + + chroma_params.depth = double_blocks + chroma_params.depth_single_blocks = single_blocks + transformer = Chroma(chroma_params) + + # add dtype, not sure why it doesnt have it + transformer.dtype = dtype + # load the state dict into the model + transformer.load_state_dict(chroma_state_dict) + + transformer.to(self.quantize_device, dtype=dtype) + + transformer.config = FakeConfig() + transformer.config.num_layers = double_blocks + transformer.config.num_single_layers = single_blocks + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = get_qtype(self.model_config.qtype) + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + 
transformer.to(self.device_torch, dtype=dtype) + + flush() + + self.print_and_status_update("Loading T5") + tokenizer_2 = T5TokenizerFast.from_pretrained( + extras_path, subfolder="tokenizer_2", torch_dtype=dtype + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + extras_path, subfolder="text_encoder_2", torch_dtype=dtype + ) + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing T5") + quantize(text_encoder_2, weights=get_qtype( + self.model_config.qtype)) + freeze(text_encoder_2) + flush() + + # self.print_and_status_update("Loading CLIP") + text_encoder = FakeCLIP() + tokenizer = FakeCLIP() + text_encoder.to(self.device_torch, dtype=dtype) + + self.noise_scheduler = ChromaRadianceModel.get_train_scheduler() + + self.print_and_status_update("Loading VAE") + # vae = AutoencoderKL.from_pretrained( + # extras_path, + # subfolder="vae", + # torch_dtype=dtype + # ) + vae = FakeVAE() + vae = vae.to(self.device_torch, dtype=dtype) + + self.print_and_status_update("Making pipe") + + pipe: ChromaPipeline = ChromaPipeline( + scheduler=self.noise_scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + is_radiance=True, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + + # save it to the model class + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") + + def get_generation_pipeline(self): + scheduler = ChromaRadianceModel.get_train_scheduler() + pipeline = ChromaPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + text_encoder_2=unwrap_model(self.text_encoder[1]), + tokenizer_2=self.tokenizer[1], + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + is_radiance=True, + ) + + # pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: ChromaPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + + extra['negative_prompt_embeds'] = unconditional_embeds.text_embeds + extra['negative_prompt_attn_mask'] = unconditional_embeds.attention_mask + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_attn_mask=conditional_embeds.attention_mask, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + 
self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + with torch.no_grad(): + bs, c, h, w = latent_model_input.shape + + img_ids = prepare_latent_image_ids( + bs, h, w, patch_size=16 + ).to(self.device_torch) + + txt_ids = torch.zeros( + bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + guidance = torch.full([1], 0, device=self.device_torch, dtype=torch.float32) + guidance = guidance.expand(bs) + + cast_dtype = self.unet.dtype + + noise_pred = self.unet( + img=latent_model_input.to( + self.device_torch, cast_dtype + ), + img_ids=img_ids, + txt=text_embeddings.text_embeds.to( + self.device_torch, cast_dtype + ), + txt_ids=txt_ids, + txt_mask=text_embeddings.attention_mask.to( + self.device_torch, cast_dtype + ), + timesteps=timestep / 1000, + guidance=guidance + ) + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + return noise_pred + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + if isinstance(prompt, str): + prompts = [prompt] + else: + prompts = prompt + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + max_length = 512 + + device = self.text_encoder[1].device + dtype = self.text_encoder[1].dtype + + # T5 + text_inputs = self.tokenizer[1]( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_embeds = self.text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder[1].dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + prompt_attention_mask = text_inputs["attention_mask"] + + pe = PromptEmbeds( + prompt_embeds + ) + pe.attention_mask = prompt_attention_mask + return pe + + def get_model_has_grad(self): + # return from a weight if it has grad + return False + def get_te_has_grad(self): + # return from a weight if it has grad + return False + + def save_model(self, output_path, meta, save_dtype): + if not output_path.endswith(".safetensors"): + output_path = output_path + ".safetensors" + # only save the unet + transformer: Chroma = unwrap_model(self.model) + state_dict = transformer.state_dict() + save_dict = {} + for k, v in state_dict.items(): + if isinstance(v, QTensor): + v = v.dequantize() + save_dict[k] = v.clone().to('cpu', dtype=save_dtype) + + meta = get_meta_for_safetensors(meta, name='chroma') + save_file(save_dict, output_path, metadata=meta) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + batch = kwargs.get('batch') + return (noise - batch.latents).detach() + + def convert_lora_weights_before_save(self, state_dict): + # currently starte with transformer. but needs to start with diffusion_model. for comfyui + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("transformer.", "diffusion_model.") + new_sd[new_key] = value + return new_sd + + def convert_lora_weights_before_load(self, state_dict): + # saved as diffusion_model. but needs to be transformer. 
for ai-toolkit + new_sd = {} + for key, value in state_dict.items(): + new_key = key.replace("diffusion_model.", "transformer.") + new_sd[new_key] = value + return new_sd + + def get_base_model_version(self): + return "chroma_radiance" diff --git a/extensions_built_in/diffusion_models/chroma/pipeline.py b/extensions_built_in/diffusion_models/chroma/pipeline.py index 215be798c..5a76a7131 100644 --- a/extensions_built_in/diffusion_models/chroma/pipeline.py +++ b/extensions_built_in/diffusion_models/chroma/pipeline.py @@ -6,6 +6,7 @@ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput from diffusers.utils import is_torch_xla_available +from diffusers.utils.torch_utils import randn_tensor if is_torch_xla_available(): @@ -16,7 +17,134 @@ XLA_AVAILABLE = False +def prepare_latent_image_ids(batch_size, height, width, patch_size=2, max_offset=0): + """ + Generates positional embeddings for a latent image. + + Args: + batch_size (int): The number of images in the batch. + height (int): The height of the image. + width (int): The width of the image. + patch_size (int, optional): The size of the patches. Defaults to 2. + max_offset (int, optional): The maximum random offset to apply. Defaults to 0. + + Returns: + torch.Tensor: A tensor containing the positional embeddings. + """ + # the random pos embedding helps generalize to larger res without training at large res + # pos embedding for rope, 2d pos embedding, corner embedding and not center based + latent_image_ids = torch.zeros(height // patch_size, width // patch_size, 3) + + # Add positional encodings + latent_image_ids[..., 1] = ( + latent_image_ids[..., 1] + torch.arange(height // patch_size)[:, None] + ) + latent_image_ids[..., 2] = ( + latent_image_ids[..., 2] + torch.arange(width // patch_size)[None, :] + ) + + # Add random offset if specified + if max_offset > 0: + offset_y = torch.randint(0, max_offset + 1, (1,)).item() + offset_x = torch.randint(0, max_offset + 1, (1,)).item() + latent_image_ids[..., 1] += offset_y + latent_image_ids[..., 2] += offset_x + + + ( + latent_image_id_height, + latent_image_id_width, + latent_image_id_channels, + ) = latent_image_ids.shape + + # Reshape for batch + latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, + latent_image_id_height * latent_image_id_width, + latent_image_id_channels, + ) + + return latent_image_ids + + class ChromaPipeline(FluxPipeline): + def __init__( + self, + scheduler, + vae, + text_encoder, + tokenizer, + text_encoder_2, + tokenizer_2, + transformer, + image_encoder = None, + feature_extractor = None, + is_radiance: bool = False, + ): + super().__init__( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.is_radiance = is_radiance + self.vae_scale_factor = 8 if not is_radiance else 1 + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
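+        # Note: for the radiance model, vae_scale_factor is set to 1 in __init__ above,
+        # so these two lines only round height/width down to an even number of pixels and
+        # the "latents" stay in pixel space with 3 channels; the 8x VAE compression
+        # described above only applies to the non-radiance path.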
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = prepare_latent_image_ids( + batch_size, + height, + width, + patch_size=2 if not self.is_radiance else 16 + ).to(device=device, dtype=dtype) + # latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if not self.is_radiance: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + # latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = prepare_latent_image_ids( + batch_size, + height, + width, + patch_size=2 if not self.is_radiance else 16 + ).to(device=device, dtype=dtype) + + return latents, latent_image_ids + def __call__( self, prompt: Union[str, List[str]] = None, @@ -70,6 +198,8 @@ def __call__( # 4. Prepare latent variables num_channels_latents = 64 // 4 + if self.is_radiance: + num_channels_latents = 3 latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -82,8 +212,8 @@ def __call__( ) # extend img ids to match batch size - latent_image_ids = latent_image_ids.unsqueeze(0) - latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0) + # latent_image_ids = latent_image_ids.unsqueeze(0) + # latent_image_ids = torch.cat([latent_image_ids] * batch_size, dim=0) # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) @@ -180,8 +310,9 @@ def __call__( image = latents else: - latents = self._unpack_latents( - latents, height, width, self.vae_scale_factor) + if not self.is_radiance: + latents = self._unpack_latents( + latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + \ self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] diff --git a/extensions_built_in/diffusion_models/chroma/src/layers.py b/extensions_built_in/diffusion_models/chroma/src/layers.py index 726ec6a02..030a57cdb 100644 --- a/extensions_built_in/diffusion_models/chroma/src/layers.py +++ b/extensions_built_in/diffusion_models/chroma/src/layers.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from .math import attention, rope +from functools import lru_cache class EmbedND(nn.Module): @@ -88,7 +89,7 @@ def forward(self, x: Tensor): # return self._forward(x) -def distribute_modulations(tensor: torch.Tensor): +def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks): """ Distributes slices of the tensor into the block_dict as ModulationOut objects. @@ -102,25 +103,25 @@ def distribute_modulations(tensor: torch.Tensor): # HARD CODED VALUES! lookup table for the generated vectors # TODO: move this into chroma config! 
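+    # Sizing note: each single block consumes 3 modulation vectors, each double block
+    # consumes 6 for its image stream and 6 for its text stream, and the final layer
+    # consumes 2, which is why mod_index_length in model.py and radiance.py is
+    # 3 * depth_single_blocks + 2 * 6 * depth_double_blocks + 2.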
# Add 38 single mod blocks - for i in range(38): + for i in range(depth_single_blocks): key = f"single_blocks.{i}.modulation.lin" block_dict[key] = None # Add 19 image double blocks - for i in range(19): + for i in range(depth_double_blocks): key = f"double_blocks.{i}.img_mod.lin" block_dict[key] = None # Add 19 text double blocks - for i in range(19): + for i in range(depth_double_blocks): key = f"double_blocks.{i}.txt_mod.lin" block_dict[key] = None # Add the final layer block_dict["final_layer.adaLN_modulation.1"] = None # 6.2b version - block_dict["lite_double_blocks.4.img_mod.lin"] = None - block_dict["lite_double_blocks.4.txt_mod.lin"] = None + # block_dict["lite_double_blocks.4.img_mod.lin"] = None + # block_dict["lite_double_blocks.4.txt_mod.lin"] = None idx = 0 # Index to keep track of the vector slices @@ -173,6 +174,219 @@ def distribute_modulations(tensor: torch.Tensor): return block_dict + +class NerfEmbedder(nn.Module): + """ + An embedder module that combines input features with a 2D positional + encoding that mimics the Discrete Cosine Transform (DCT). + + This module takes an input tensor of shape (B, P^2, C), where P is the + patch size, and enriches it with positional information before projecting + it to a new hidden size. + """ + def __init__(self, in_channels, hidden_size_input, max_freqs): + """ + Initializes the NerfEmbedder. + + Args: + in_channels (int): The number of channels in the input tensor. + hidden_size_input (int): The desired dimension of the output embedding. + max_freqs (int): The number of frequency components to use for both + the x and y dimensions of the positional encoding. + The total number of positional features will be max_freqs^2. + """ + super().__init__() + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + + # A linear layer to project the concatenated input features and + # positional encodings to the final output dimension. + self.embedder = nn.Sequential( + nn.Linear(in_channels + max_freqs**2, hidden_size_input) + ) + + @lru_cache(maxsize=4) + def fetch_pos(self, patch_size, device, dtype): + """ + Generates and caches 2D DCT-like positional embeddings for a given patch size. + + The LRU cache is a performance optimization that avoids recomputing the + same positional grid on every forward pass. + + Args: + patch_size (int): The side length of the square input patch. + device: The torch device to create the tensors on. + dtype: The torch dtype for the tensors. + + Returns: + A tensor of shape (1, patch_size^2, max_freqs^2) containing the + positional embeddings. + """ + # Create normalized 1D coordinate grids from 0 to 1. + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + + # Create a 2D meshgrid of coordinates. + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + + # Reshape positions to be broadcastable with frequencies. + # Shape becomes (patch_size^2, 1, 1). + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + # Create a 1D tensor of frequency values from 0 to max_freqs-1. + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) + + # Reshape frequencies to be broadcastable for creating 2D basis functions. + # freqs_x shape: (1, max_freqs, 1) + # freqs_y shape: (1, 1, max_freqs) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + + # A custom weighting coefficient, not part of standard DCT. 
+ # This seems to down-weight the contribution of higher-frequency interactions. + coeffs = (1 + freqs_x * freqs_y) ** -1 + + # Calculate the 1D cosine basis functions for x and y coordinates. + # This is the core of the DCT formulation. + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + + # Combine the 1D basis functions to create 2D basis functions by element-wise + # multiplication, and apply the custom coefficients. Broadcasting handles the + # combination of all (pos_x, freqs_x) with all (pos_y, freqs_y). + # The result is flattened into a feature vector for each position. + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + + return dct + + def forward(self, inputs): + """ + Forward pass for the embedder. + + Args: + inputs (Tensor): The input tensor of shape (B, P^2, C). + + Returns: + Tensor: The output tensor of shape (B, P^2, hidden_size_input). + """ + # Get the batch size, number of pixels, and number of channels. + B, P2, C = inputs.shape + # Store the original dtype to cast back to at the end. + original_dtype = inputs.dtype + # Force all operations within this module to run in fp32. + with torch.autocast("cuda", enabled=False): + # Infer the patch side length from the number of pixels (P^2). + patch_size = int(P2 ** 0.5) + + inputs = inputs.float() + # Fetch the pre-computed or cached positional embeddings. + dct = self.fetch_pos(patch_size, inputs.device, torch.float32) + + # Repeat the positional embeddings for each item in the batch. + dct = dct.repeat(B, 1, 1) + + # Concatenate the original input features with the positional embeddings + # along the feature dimension. + inputs = torch.cat([inputs, dct], dim=-1) + + # Project the combined tensor to the target hidden size. + inputs = self.embedder.float()(inputs) + + return inputs.to(original_dtype) + + + +class NerfGLUBlock(nn.Module): + """ + A NerfBlock using a Gated Linear Unit (GLU) like MLP. + """ + def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, use_compiled): + super().__init__() + # The total number of parameters for the MLP is increased to accommodate + # the gate, value, and output projection matrices. + # We now need to generate parameters for 3 matrices. + total_params = 3 * hidden_size_x**2 * mlp_ratio + self.param_generator = nn.Linear(hidden_size_s, total_params) + self.norm = RMSNorm(hidden_size_x, use_compiled) + self.mlp_ratio = mlp_ratio + # nn.init.zeros_(self.param_generator.weight) + # nn.init.zeros_(self.param_generator.bias) + + + def forward(self, x, s): + batch_size, num_x, hidden_size_x = x.shape + mlp_params = self.param_generator(s) + + # Split the generated parameters into three parts for the gate, value, and output projection. + fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1) + + # Reshape the parameters into matrices for batch matrix multiplication. + fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x) + + # Normalize the generated weight matrices as in the original implementation. + fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2) + fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2) + fc2 = torch.nn.functional.normalize(fc2, dim=-2) + + res_x = x + x = self.norm(x) + + # Apply the final output projection. 
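+        # i.e. a per-sample GLU MLP evaluated with torch.bmm, using the weight matrices
+        # generated from s above: x = (SiLU(x @ fc1_gate) * (x @ fc1_value)) @ fc2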
+ x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) + + x = x + res_x + return x + + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, use_compiled): + super().__init__() + self.norm = RMSNorm(hidden_size, use_compiled=use_compiled) + self.linear = nn.Linear(hidden_size, out_channels) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x): + x = self.norm(x) + x = self.linear(x) + return x + + +class NerfFinalLayerConv(nn.Module): + def __init__(self, hidden_size, out_channels, use_compiled): + super().__init__() + self.norm = RMSNorm(hidden_size, use_compiled=use_compiled) + + # replace nn.Linear with nn.Conv2d since linear is just pointwise conv + self.conv = nn.Conv2d( + in_channels=hidden_size, + out_channels=out_channels, + kernel_size=3, + padding=1 + ) + nn.init.zeros_(self.conv.weight) + nn.init.zeros_(self.conv.bias) + + def forward(self, x): + # shape: [N, C, H, W] ! + # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1. + # So, we permute the dimensions to make the channel dimension the last one. + x_permuted = x.permute(0, 2, 3, 1) # Shape becomes [N, H, W, C] + + # Apply normalization on the feature/channel dimension + x_norm = self.norm(x_permuted) + + # Permute back to the original dimension order for the convolution + x_norm_permuted = x_norm.permute(0, 3, 1, 2) # Shape becomes [N, C, H, W] + + # Apply the 3x3 convolution + x = self.conv(x_norm_permuted) + return x + + class Approximator(nn.Module): def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4): super().__init__() diff --git a/extensions_built_in/diffusion_models/chroma/src/model.py b/extensions_built_in/diffusion_models/chroma/src/model.py index 33cdbe62f..ebdf69d93 100644 --- a/extensions_built_in/diffusion_models/chroma/src/model.py +++ b/extensions_built_in/diffusion_models/chroma/src/model.py @@ -156,13 +156,19 @@ def __init__(self, params: ChromaParams): ) # TODO: move this hardcoded value to config - self.mod_index_length = 344 + # single layer has 3 modulation vectors + # double layer has 6 modulation vectors for each expert + # final layer has 2 modulation vectors + self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2 + self.depth_single_blocks = params.depth_single_blocks + self.depth_double_blocks = params.depth # self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0) self.register_buffer( "mod_index", torch.tensor(list(range(self.mod_index_length)), device="cpu"), persistent=False, ) + self.approximator_in_dim = params.approximator_in_dim @property def device(self): @@ -213,7 +219,7 @@ def forward( # then and only then we could concatenate it together input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True)) - mod_vectors_dict = distribute_modulations(mod_vectors) + mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) diff --git a/extensions_built_in/diffusion_models/chroma/src/radiance.py b/extensions_built_in/diffusion_models/chroma/src/radiance.py new file mode 100644 index 000000000..d328f261e --- /dev/null +++ b/extensions_built_in/diffusion_models/chroma/src/radiance.py @@ -0,0 +1,380 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +import 
torch.utils.checkpoint as ckpt + +from .layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + SingleStreamBlock, + timestep_embedding, + Approximator, + distribute_modulations, + NerfEmbedder, + NerfFinalLayer, + NerfFinalLayerConv, + NerfGLUBlock +) + + +@dataclass +class ChromaParams: + in_channels: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + approximator_in_dim: int + approximator_depth: int + approximator_hidden_size: int + patch_size: int + nerf_hidden_size: int + nerf_mlp_ratio: int + nerf_depth: int + nerf_max_freqs: int + _use_compiled: bool + + +chroma_params = ChromaParams( + in_channels=3, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + approximator_in_dim=64, + approximator_depth=5, + approximator_hidden_size=5120, + patch_size=16, + nerf_hidden_size=64, + nerf_mlp_ratio=4, + nerf_depth=4, + nerf_max_freqs=8, + _use_compiled=False, +) + + +def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8): + """ + Modifies attention mask to allow attention to a few extra padding tokens. + + Args: + mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens) + max_seq_length: Maximum sequence length of the model + num_extra_padding: Number of padding tokens to unmask + + Returns: + Modified mask + """ + # Get the actual sequence length from the mask + seq_length = mask.sum(dim=-1) + batch_size = mask.shape[0] + + modified_mask = mask.clone() + + for i in range(batch_size): + current_seq_len = int(seq_length[i].item()) + + # Only add extra padding tokens if there's room + if current_seq_len < max_seq_length: + # Calculate how many padding tokens we can unmask + available_padding = max_seq_length - current_seq_len + tokens_to_unmask = min(num_extra_padding, available_padding) + + # Unmask the specified number of padding tokens right after the sequence + modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1 + + return modified_mask + + +class Chroma(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: ChromaParams): + super().__init__() + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + self.gradient_checkpointing = False + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + # self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + # patchify ops + self.img_in_patch = nn.Conv2d( + params.in_channels, + params.hidden_size, + kernel_size=params.patch_size, + stride=params.patch_size, + bias=True + ) + nn.init.zeros_(self.img_in_patch.weight) + nn.init.zeros_(self.img_in_patch.bias) + # TODO: need proper mapping for this approximator output! 
+ # currently the mapping is hardcoded in distribute_modulations function + self.distilled_guidance_layer = Approximator( + params.approximator_in_dim, + self.hidden_size, + params.approximator_hidden_size, + params.approximator_depth, + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + use_compiled=params._use_compiled, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + use_compiled=params._use_compiled, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + # self.final_layer = LastLayer( + # self.hidden_size, + # 1, + # self.out_channels, + # use_compiled=params._use_compiled, + # ) + + # pixel channel concat with DCT + self.nerf_image_embedder = NerfEmbedder( + in_channels=params.in_channels, + hidden_size_input=params.nerf_hidden_size, + max_freqs=params.nerf_max_freqs + ) + + self.nerf_blocks = nn.ModuleList([ + NerfGLUBlock( + hidden_size_s=params.hidden_size, + hidden_size_x=params.nerf_hidden_size, + mlp_ratio=params.nerf_mlp_ratio, + use_compiled=params._use_compiled + ) for _ in range(params.nerf_depth) + ]) + # self.nerf_final_layer = NerfFinalLayer( + # params.nerf_hidden_size, + # out_channels=params.in_channels, + # use_compiled=params._use_compiled + # ) + self.nerf_final_layer_conv = NerfFinalLayerConv( + params.nerf_hidden_size, + out_channels=params.in_channels, + use_compiled=params._use_compiled + ) + # TODO: move this hardcoded value to config + # single layer has 3 modulation vectors + # double layer has 6 modulation vectors for each expert + # final layer has 2 modulation vectors + self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2 + self.depth_single_blocks = params.depth_single_blocks + self.depth_double_blocks = params.depth + # self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0) + self.register_buffer( + "mod_index", + torch.tensor(list(range(self.mod_index_length)), device="cpu"), + persistent=False, + ) + self.approximator_in_dim = params.approximator_in_dim + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def enable_gradient_checkpointing(self, enable: bool = True): + self.gradient_checkpointing = enable + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + txt_mask: Tensor, + timesteps: Tensor, + guidance: Tensor, + attn_padding: int = 1, + ) -> Tensor: + if img.ndim != 4: + raise ValueError("Input img tensor must be in [B, C, H, W] format.") + if txt.ndim != 3: + raise ValueError("Input txt tensors must have 3 dimensions.") + B, C, H, W = img.shape + + # gemini gogogo idk how to unfold and pack the patch properly :P + # Store the raw pixel values of each patch for the NeRF head later. + # unfold creates patches: [B, C * P * P, NumPatches] + nerf_pixels = nn.functional.unfold(img, kernel_size=self.params.patch_size, stride=self.params.patch_size) + nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] + + # partchify ops + img = self.img_in_patch(img) # -> [B, Hidden, H/P, W/P] + num_patches = img.shape[2] * img.shape[3] + # flatten into a sequence for the transformer. 
+ img = img.flatten(2).transpose(1, 2) # -> [B, NumPatches, Hidden] + + txt = self.txt_in(txt) + + # TODO: + # need to fix grad accumulation issue here for now it's in no grad mode + # besides, i don't want to wash out the PFP that's trained on this model weights anyway + # the fan out operation here is deleting the backward graph + # alternatively doing forward pass for every block manually is doable but slow + # custom backward probably be better + with torch.no_grad(): + distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim//4) + # TODO: need to add toggle to omit this from schnell but that's not a priority + distil_guidance = timestep_embedding(guidance, self.approximator_in_dim//4) + # get all modulation index + modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim//2) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) + # and we need to broadcast timestep and guidance along too + timestep_guidance = ( + torch.cat([distill_timestep, distil_guidance], dim=1) + .unsqueeze(1) + .repeat(1, self.mod_index_length, 1) + ) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) + mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True)) + mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + # compute mask + # assume max seq length from the batched input + + max_len = txt.shape[1] + + # mask + with torch.no_grad(): + txt_mask_w_padding = modify_mask_to_attend_padding( + txt_mask, max_len, attn_padding + ) + txt_img_mask = torch.cat( + [ + txt_mask_w_padding, + torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device), + ], + dim=1, + ) + txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float() + txt_img_mask = ( + txt_img_mask[None, None, ...] + .repeat(txt.shape[0], self.num_heads, 1, 1) + .int() + .bool() + ) + # txt_mask_w_padding[txt_mask_w_padding==False] = True + + for i, block in enumerate(self.double_blocks): + # the guidance replaced by FFN output + img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] + txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] + double_mod = [img_mod, txt_mod] + + # just in case in different GPU for simple pipeline parallel + if torch.is_grad_enabled() and self.gradient_checkpointing: + img.requires_grad_(True) + img, txt = ckpt.checkpoint( + block, img, txt, pe, double_mod, txt_img_mask + ) + else: + img, txt = block( + img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask + ) + + img = torch.cat((txt, img), 1) + for i, block in enumerate(self.single_blocks): + single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + if torch.is_grad_enabled() and self.gradient_checkpointing: + img.requires_grad_(True) + img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask) + else: + img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) + img = img[:, txt.shape[1] :, ...] 
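The patch bookkeeping in this forward pass relies on torch.nn.functional.unfold and fold with kernel_size equal to stride, which makes the two operations exact inverses because patches never overlap. A small stand-alone sketch of that round trip, with purely illustrative sizes:

```python
import torch
import torch.nn.functional as F

B, C, H, W, P = 2, 3, 64, 64, 16
img = torch.randn(B, C, H, W)

# unfold: [B, C, H, W] -> [B, C*P*P, NumPatches]
patches = F.unfold(img, kernel_size=P, stride=P)
assert patches.shape == (B, C * P * P, (H // P) * (W // P))

# fold with the same kernel/stride reassembles the image exactly,
# since non-overlapping patches contribute to each pixel exactly once
recon = F.fold(patches, output_size=(H, W), kernel_size=P, stride=P)
assert torch.allclose(recon, img)
```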
+ + # final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] + # img = self.final_layer( + # img, distill_vec=final_mod + # ) # (N, T, patch_size ** 2 * out_channels) + + # aliasing + nerf_hidden = img + # reshape for per-patch processing + nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2) + + # get DCT-encoded pixel embeddings [pixel-dct] + img_dct = self.nerf_image_embedder(nerf_pixels) + + # pass through the dynamic MLP blocks (the NeRF) + for i, block in enumerate(self.nerf_blocks): + if self.training: + img_dct = ckpt.checkpoint(block, img_dct, nerf_hidden) + else: + img_dct = block(img_dct, nerf_hidden) + + # final projection to get the output pixel values + # img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C] + img_dct = self.nerf_final_layer_conv.norm(img_dct) + + # gemini gogogo idk how to fold this properly :P + # Reassemble the patches into the final image. + img_dct = img_dct.transpose(1, 2) # -> [B*NumPatches, C, P*P] + # Reshape to combine with batch dimension for fold + img_dct = img_dct.reshape(B, num_patches, -1) # -> [B, NumPatches, C*P*P] + img_dct = img_dct.transpose(1, 2) # -> [B, C*P*P, NumPatches] + img_dct = nn.functional.fold( + img_dct, + output_size=(H, W), + kernel_size=self.params.patch_size, + stride=self.params.patch_size + ) # [B, Hidden, H, W] + img_dct = self.nerf_final_layer_conv.conv(img_dct) + + return img_dct \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/qwen_image/__init__.py b/extensions_built_in/diffusion_models/qwen_image/__init__.py index d8b32a851..3ff1797bf 100644 --- a/extensions_built_in/diffusion_models/qwen_image/__init__.py +++ b/extensions_built_in/diffusion_models/qwen_image/__init__.py @@ -1,2 +1,3 @@ from .qwen_image import QwenImageModel -from .qwen_image_edit import QwenImageEditModel \ No newline at end of file +from .qwen_image_edit import QwenImageEditModel +from .qwen_image_edit_plus import QwenImageEditPlusModel diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py index 07dcde15a..edfe0c13d 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py @@ -284,11 +284,11 @@ def get_noise_prediction( txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() noise_pred = self.transformer( - hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), - timestep=timestep / 1000, + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype).detach(), + timestep=(timestep / 1000).detach(), guidance=None, - encoder_hidden_states=enc_hs, - encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=enc_hs.detach(), + encoder_hidden_states_mask=prompt_embeds_mask.detach(), img_shapes=img_shapes, txt_seq_lens=txt_seq_lens, return_dict=False, diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py new file mode 100644 index 000000000..222431f28 --- /dev/null +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py @@ -0,0 +1,312 @@ +import math +import torch +from .qwen_image import QwenImageModel +import os +from typing import TYPE_CHECKING, List, Optional +import yaml +from toolkit import train_tools +from toolkit.config_modules import 
GenerateImageConfig, ModelConfig +from PIL import Image +from toolkit.models.base_model import BaseModel +from toolkit.basic import flush +from toolkit.prompt_utils import PromptEmbeds +from toolkit.samplers.custom_flowmatch_sampler import ( + CustomFlowMatchEulerDiscreteScheduler, +) +from toolkit.accelerator import get_accelerator, unwrap_model +from optimum.quanto import freeze, QTensor +from toolkit.util.quantize import quantize, get_qtype, quantize_model +import torch.nn.functional as F + +from diffusers import ( + QwenImageTransformer2DModel, + AutoencoderKLQwenImage, +) +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer +from tqdm import tqdm + + +if TYPE_CHECKING: + from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO + +try: + from diffusers import QwenImageEditPlusPipeline + from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, VAE_IMAGE_SIZE +except ImportError: + raise ImportError( + "Diffusers is out of date. Update diffusers to the latest version by doing 'pip uninstall diffusers' and then 'pip install -r requirements.txt'" + ) + + +class QwenImageEditPlusModel(QwenImageModel): + arch = "qwen_image_edit_plus" + _qwen_image_keep_visual = True + _qwen_pipeline = QwenImageEditPlusPipeline + + def __init__( + self, + device, + model_config: ModelConfig, + dtype="bf16", + custom_pipeline=None, + noise_scheduler=None, + **kwargs, + ): + super().__init__( + device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs + ) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ["QwenImageTransformer2DModel"] + + # set true for models that encode control image into text embeddings + self.encode_control_in_text_embeddings = True + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = True + # do not resize control images + self.use_raw_control_images = True + + def load_model(self): + super().load_model() + + def get_generation_pipeline(self): + scheduler = QwenImageModel.get_train_scheduler() + + pipeline: QwenImageEditPlusPipeline = QwenImageEditPlusPipeline( + scheduler=scheduler, + text_encoder=unwrap_model(self.text_encoder[0]), + tokenizer=self.tokenizer[0], + processor=self.processor, + vae=unwrap_model(self.vae), + transformer=unwrap_model(self.transformer), + ) + + pipeline = pipeline.to(self.device_torch) + + return pipeline + + def generate_single_image( + self, + pipeline: QwenImageEditPlusPipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + self.model.to(self.device_torch, dtype=self.torch_dtype) + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + control_img_list = [] + if gen_config.ctrl_img is not None: + control_img = Image.open(gen_config.ctrl_img) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + elif gen_config.ctrl_img_1 is not None: + control_img = Image.open(gen_config.ctrl_img_1) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + + if gen_config.ctrl_img_2 is not None: + control_img = Image.open(gen_config.ctrl_img_2) + control_img = control_img.convert("RGB") + control_img_list.append(control_img) + if gen_config.ctrl_img_3 is not None: + control_img = Image.open(gen_config.ctrl_img_3) + control_img = 
control_img.convert("RGB") + control_img_list.append(control_img) + + # flush for low vram if we are doing that + # flush_between_steps = self.model_config.low_vram + flush_between_steps = False + + # Fix a bug in diffusers/torch + def callback_on_step_end(pipe, i, t, callback_kwargs): + if flush_between_steps: + flush() + latents = callback_kwargs["latents"] + + return {"latents": latents} + + img = pipeline( + image=control_img_list, + prompt_embeds=conditional_embeds.text_embeds, + prompt_embeds_mask=conditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to( + self.device_torch, dtype=torch.int64 + ), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + true_cfg_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + callback_on_step_end=callback_on_step_end, + **extra, + ).images[0] + return img + + def condition_noisy_latents( + self, latents: torch.Tensor, batch: "DataLoaderBatchDTO" + ): + # we get the control image from the batch + return latents.detach() + + def get_prompt_embeds(self, prompt: str, control_images=None) -> PromptEmbeds: + # todo handle not caching text encoder + if self.pipeline.text_encoder.device != self.device_torch: + self.pipeline.text_encoder.to(self.device_torch) + + if control_images is not None and len(control_images) > 0: + for i in range(len(control_images)): + # control images are 0 - 1 scale, shape (bs, ch, height, width) + ratio = control_images[i].shape[2] / control_images[i].shape[3] + width = math.sqrt(CONDITION_IMAGE_SIZE * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + control_images[i] = F.interpolate( + control_images[i], size=(height, width), mode="bilinear" + ) + + prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt( + prompt, + image=control_images, + device=self.device_torch, + num_images_per_prompt=1, + ) + pe = PromptEmbeds(prompt_embeds) + pe.attention_mask = prompt_embeds_mask + return pe + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + batch: "DataLoaderBatchDTO" = None, + **kwargs, + ): + with torch.no_grad(): + batch_size, num_channels_latents, height, width = latent_model_input.shape + + # pack image tokens + latent_model_input = latent_model_input.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5) + latent_model_input = latent_model_input.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + raw_packed_latents = latent_model_input + + img_h2, img_w2 = height // 2, width // 2 + + img_shapes = [ + [(1, img_h2, img_w2)] + ] * batch_size + + # pack controls + if batch is None: + raise ValueError("Batch is required for QwenImageEditPlusModel") + + # split the latents into batch items so we can concat the controls + packed_latents_list = torch.chunk(latent_model_input, batch_size, dim=0) + packed_latents_with_controls_list = [] + + if batch.control_tensor_list is not None: + if len(batch.control_tensor_list) != batch_size: + raise ValueError("Control tensor list length does not match batch size") + b = 0 + for control_tensor_list in batch.control_tensor_list: + # control tensor list is a list of tensors for 
this batch item + controls = [] + # pack control + for control_img in control_tensor_list: + # control images are 0 - 1 scale, shape (1, ch, height, width) + control_img = control_img.to(self.device_torch, dtype=self.torch_dtype) + # if it is only 3 dim, add batch dim + if len(control_img.shape) == 3: + control_img = control_img.unsqueeze(0) + ratio = control_img.shape[2] / control_img.shape[3] + c_width = math.sqrt(VAE_IMAGE_SIZE * ratio) + c_height = c_width / ratio + + c_width = round(c_width / 32) * 32 + c_height = round(c_height / 32) * 32 + + control_img = F.interpolate( + control_img, size=(c_height, c_width), mode="bilinear" + ) + + # scale to -1 to 1 + control_img = control_img * 2 - 1 + + control_latent = self.encode_images( + control_img, + device=self.device_torch, + dtype=self.torch_dtype, + ) + + clb, cl_num_channels_latents, cl_height, cl_width = control_latent.shape + + control = control_latent.view( + 1, cl_num_channels_latents, cl_height // 2, 2, cl_width // 2, 2 + ) + control = control.permute(0, 2, 4, 1, 3, 5) + control = control.reshape( + 1, (cl_height // 2) * (cl_width // 2), num_channels_latents * 4 + ) + + img_shapes[b].append((1, cl_height // 2, cl_width // 2)) + controls.append(control) + + # stack controls on dim 1 + control = torch.cat(controls, dim=1).to(packed_latents_list[b].device, dtype=packed_latents_list[b].dtype) + # concat with latents + packed_latents_with_control = torch.cat([packed_latents_list[b], control], dim=1) + + packed_latents_with_controls_list.append(packed_latents_with_control) + + b += 1 + + latent_model_input = torch.cat(packed_latents_with_controls_list, dim=0) + + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() + enc_hs = text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype) + prompt_embeds_mask = text_embeddings.attention_mask.to( + self.device_torch, dtype=torch.int64 + ) + + noise_pred = self.transformer( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype).detach(), + timestep=(timestep / 1000).detach(), + guidance=None, + encoder_hidden_states=enc_hs.detach(), + encoder_hidden_states_mask=prompt_embeds_mask.detach(), + img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, + return_dict=False, + **kwargs, + )[0] + + noise_pred = noise_pred[:, : raw_packed_latents.size(1)] + + # unpack + noise_pred = noise_pred.view( + batch_size, height // 2, width // 2, num_channels_latents, 2, 2 + ) + noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5) + noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width) + return noise_pred diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index aee752723..fdd76051f 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -526,6 +526,44 @@ def load_lora(self, file: str): combined_dict = new_dict return combined_dict + + def generate_single_image( + self, + pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) + # todo, figure out how to do video + output = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), 
+ negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + num_frames=gen_config.num_frames, + generator=generator, + return_dict=False, + output_type="pil", + **extra + )[0] + + # shape = [1, frames, channels, height, width] + batch_item = output[0] # list of pil images + if gen_config.num_frames > 1: + return batch_item # return the frames. + else: + # get just the first image + img = batch_item[0] + return img def get_model_to_train(self): # todo, loras wont load right unless they have the transformer_1 or transformer_2 in the key. diff --git a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py index f434327fd..8d7db2420 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py @@ -73,6 +73,14 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if num_frames % self.vae_scale_factor_temporal != 1: + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + + width = width // (self.vae.config.scale_factor_spatial * 2) * (self.vae.config.scale_factor_spatial * 2) + height = height // (self.vae.config.scale_factor_spatial * 2) * (self.vae.config.scale_factor_spatial * 2) # unload vae and transformer vae_device = self.vae.device diff --git a/extensions_built_in/sd_trainer/DiffusionTrainer.py b/extensions_built_in/sd_trainer/DiffusionTrainer.py new file mode 100644 index 000000000..730a1aee0 --- /dev/null +++ b/extensions_built_in/sd_trainer/DiffusionTrainer.py @@ -0,0 +1,297 @@ +from collections import OrderedDict +import os +import sqlite3 +import asyncio +import concurrent.futures +from extensions_built_in.sd_trainer.SDTrainer import SDTrainer +from typing import Literal, Optional +import threading +import time +import signal + +AITK_Status = Literal["running", "stopped", "error", "completed"] + + +class DiffusionTrainer(SDTrainer): + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super(DiffusionTrainer, self).__init__(process_id, job, config, **kwargs) + self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db") + self.job_id = os.environ.get("AITK_JOB_ID", None) + self.job_id = self.job_id.strip() if self.job_id is not None else None + self.is_ui_trainer = True + if not os.path.exists(self.sqlite_db_path): + self.is_ui_trainer = False + else: + print(f"Using SQLite database at {self.sqlite_db_path}") + if self.job_id is None: + self.is_ui_trainer = False + else: + print(f"Job ID: \"{self.job_id}\"") + + if self.is_ui_trainer: + self.is_stopping = False + # Create a thread pool for database operations + self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # Track all async tasks + self._async_tasks = [] + # Initialize the status + self._run_async_operation(self._update_status("running", "Starting")) + self._stop_watcher_started = False + # self.start_stop_watcher(interval_sec=2.0) + + def start_stop_watcher(self, interval_sec: float = 5.0): + """ + Start a daemon thread that periodically checks should_stop() + and terminates the process immediately when 
triggered. + """ + if not self.is_ui_trainer: + return + if getattr(self, "_stop_watcher_started", False): + return + self._stop_watcher_started = True + t = threading.Thread( + target=self._stop_watcher_thread, args=(interval_sec,), daemon=True + ) + t.start() + + def _stop_watcher_thread(self, interval_sec: float): + while True: + try: + if self.should_stop(): + # Mark and update status (non-blocking; uses existing infra) + self.is_stopping = True + self._run_async_operation( + self._update_status("stopped", "Job stopped (remote)") + ) + # Best-effort flush pending async ops + try: + asyncio.run(self.wait_for_all_async()) + except RuntimeError: + pass + # Try to stop DB thread pool quickly + try: + self.thread_pool.shutdown(wait=False, cancel_futures=True) + except TypeError: + self.thread_pool.shutdown(wait=False) + print("") + print("****************************************************") + print(" Stop signal received; terminating process. ") + print("****************************************************") + os.kill(os.getpid(), signal.SIGINT) + time.sleep(interval_sec) + except Exception: + time.sleep(interval_sec) + + def _run_async_operation(self, coro): + """Helper method to run an async coroutine and track the task.""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No event loop exists, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create a task and track it + if loop.is_running(): + task = asyncio.run_coroutine_threadsafe(coro, loop) + self._async_tasks.append(asyncio.wrap_future(task)) + else: + task = loop.create_task(coro) + self._async_tasks.append(task) + loop.run_until_complete(task) + + async def _execute_db_operation(self, operation_func): + """Execute a database operation in a separate thread to avoid blocking.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self.thread_pool, operation_func) + + def _db_connect(self): + """Create a new connection for each operation to avoid locking.""" + conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0) + conn.isolation_level = None # Enable autocommit mode + return conn + + def should_stop(self): + if not self.is_ui_trainer: + return False + def _check_stop(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT stop FROM Job WHERE id = ?", (self.job_id,)) + stop = cursor.fetchone() + return False if stop is None else stop[0] == 1 + + return _check_stop() + + def maybe_stop(self): + if not self.is_ui_trainer: + return + if self.should_stop(): + self._run_async_operation( + self._update_status("stopped", "Job stopped")) + self.is_stopping = True + raise Exception("Job stopped") + + async def _update_key(self, key, value): + if not self.accelerator.is_main_process: + return + + def _do_update(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute("BEGIN IMMEDIATE") + try: + # Convert the value to string if it's not already + if isinstance(value, str): + value_to_insert = value + else: + value_to_insert = str(value) + + # Use parameterized query for both the column name and value + update_query = f"UPDATE Job SET {key} = ? WHERE id = ?" 
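One note on this query: SQLite placeholders can only bind values, not identifiers, so the column name has to be interpolated while the value and job id remain bound parameters. In this trainer the key is always an internal column name such as step or speed_string, but a stand-alone sketch of the same pattern with an explicit whitelist looks like the following (the helper and the whitelist are illustrative, not part of the codebase):

```python
import sqlite3

ALLOWED_JOB_COLUMNS = {"step", "status", "info", "speed_string"}

def update_job_column(db_path: str, job_id: str, key: str, value) -> None:
    # identifiers cannot be bound with "?", so validate against a fixed whitelist
    if key not in ALLOWED_JOB_COLUMNS:
        raise ValueError(f"refusing to update unknown column: {key}")
    with sqlite3.connect(db_path, timeout=10.0) as conn:
        # the value and the job id are still bound parameters
        conn.execute(f"UPDATE Job SET {key} = ? WHERE id = ?", (str(value), job_id))
        conn.commit()
```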
+ cursor.execute( + update_query, (value_to_insert, self.job_id)) + finally: + cursor.execute("COMMIT") + + await self._execute_db_operation(_do_update) + + def update_step(self): + """Non-blocking update of the step count.""" + if self.accelerator.is_main_process and self.is_ui_trainer: + self._run_async_operation(self._update_key("step", self.step_num)) + + def update_db_key(self, key, value): + """Non-blocking update a key in the database.""" + if self.accelerator.is_main_process and self.is_ui_trainer: + self._run_async_operation(self._update_key(key, value)) + + async def _update_status(self, status: AITK_Status, info: Optional[str] = None): + if not self.accelerator.is_main_process or not self.is_ui_trainer: + return + + def _do_update(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute("BEGIN IMMEDIATE") + try: + if info is not None: + cursor.execute( + "UPDATE Job SET status = ?, info = ? WHERE id = ?", + (status, info, self.job_id) + ) + else: + cursor.execute( + "UPDATE Job SET status = ? WHERE id = ?", + (status, self.job_id) + ) + finally: + cursor.execute("COMMIT") + + await self._execute_db_operation(_do_update) + + def update_status(self, status: AITK_Status, info: Optional[str] = None): + """Non-blocking update of status.""" + if self.accelerator.is_main_process and self.is_ui_trainer: + self._run_async_operation(self._update_status(status, info)) + + async def wait_for_all_async(self): + """Wait for all tracked async operations to complete.""" + if not self._async_tasks: + return + + try: + await asyncio.gather(*self._async_tasks) + except Exception as e: + pass + finally: + # Clear the task list after completion + self._async_tasks.clear() + + def on_error(self, e: Exception): + super(DiffusionTrainer, self).on_error(e) + if self.is_ui_trainer: + if self.accelerator.is_main_process and not self.is_stopping: + self.update_status("error", str(e)) + self.update_db_key("step", self.last_save_step) + asyncio.run(self.wait_for_all_async()) + self.thread_pool.shutdown(wait=True) + + def handle_timing_print_hook(self, timing_dict): + if "train_loop" not in timing_dict: + print("train_loop not found in timing_dict", timing_dict) + return + seconds_per_iter = timing_dict["train_loop"] + # determine iter/sec or sec/iter + if seconds_per_iter < 1: + iters_per_sec = 1 / seconds_per_iter + self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec") + else: + self.update_db_key( + "speed_string", f"{seconds_per_iter:.2f} sec/iter") + + def done_hook(self): + super(DiffusionTrainer, self).done_hook() + if self.is_ui_trainer: + self.update_status("completed", "Training completed") + # Wait for all async operations to finish before shutting down + asyncio.run(self.wait_for_all_async()) + self.thread_pool.shutdown(wait=True) + + def end_step_hook(self): + super(DiffusionTrainer, self).end_step_hook() + if self.is_ui_trainer: + self.update_step() + self.maybe_stop() + + def hook_before_model_load(self): + super().hook_before_model_load() + if self.is_ui_trainer: + self.maybe_stop() + self.update_status("running", "Loading model") + + def before_dataset_load(self): + super().before_dataset_load() + if self.is_ui_trainer: + self.maybe_stop() + self.update_status("running", "Loading dataset") + + def hook_before_train_loop(self): + super().hook_before_train_loop() + if self.is_ui_trainer: + self.maybe_stop() + self.update_step() + self.update_status("running", "Training") + self.timer.add_after_print_hook(self.handle_timing_print_hook) + + def 
status_update_hook_func(self, string): + self.update_status("running", string) + + def hook_after_sd_init_before_load(self): + super().hook_after_sd_init_before_load() + if self.is_ui_trainer: + self.maybe_stop() + self.sd.add_status_update_hook(self.status_update_hook_func) + + def sample_step_hook(self, img_num, total_imgs): + super().sample_step_hook(img_num, total_imgs) + if self.is_ui_trainer: + self.maybe_stop() + self.update_status( + "running", f"Generating images - {img_num + 1}/{total_imgs}") + + def sample(self, step=None, is_first=False): + self.maybe_stop() + total_imgs = len(self.sample_config.prompts) + self.update_status("running", f"Generating images - 0/{total_imgs}") + super().sample(step, is_first) + self.maybe_stop() + self.update_status("running", "Training") + + def save(self, step=None): + self.maybe_stop() + self.update_status("running", "Saving model") + super().save(step) + self.maybe_stop() + self.update_status("running", "Training") diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index e8b6608f4..86c198ea4 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -34,7 +34,7 @@ import math from toolkit.train_tools import precondition_model_outputs_flow_match from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe -from toolkit.util.wavelet_loss import wavelet_loss +from toolkit.util.losses import wavelet_loss, stepped_loss import torch.nn.functional as F from toolkit.unloader import unload_text_encoder from PIL import Image @@ -129,17 +129,64 @@ def cache_sample_prompts(self): prompt=prompt, # it will autoparse the prompt negative_prompt=sample_item.neg, output_path=output_path, - ctrl_img=sample_item.ctrl_img + ctrl_img=sample_item.ctrl_img, + ctrl_img_1=sample_item.ctrl_img_1, + ctrl_img_2=sample_item.ctrl_img_2, + ctrl_img_3=sample_item.ctrl_img_3, ) + + has_control_images = False + if gen_img_config.ctrl_img is not None or gen_img_config.ctrl_img_1 is not None or gen_img_config.ctrl_img_2 is not None or gen_img_config.ctrl_img_3 is not None: + has_control_images = True # see if we need to encode the control images - if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None: - ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.sd.device_torch, dtype=self.sd.torch_dtype) - ) + if self.sd.encode_control_in_text_embeddings and has_control_images: + + ctrl_img_list = [] + + if gen_img_config.ctrl_img is not None: + ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img) + + if gen_img_config.ctrl_img_1 is not None: + ctrl_img_1 = Image.open(gen_img_config.ctrl_img_1).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_1 = ( + TF.to_tensor(ctrl_img_1) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_1) + if gen_img_config.ctrl_img_2 is not None: + ctrl_img_2 = Image.open(gen_img_config.ctrl_img_2).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_2 = ( + TF.to_tensor(ctrl_img_2) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_2) + if gen_img_config.ctrl_img_3 is not None: + ctrl_img_3 = 
Image.open(gen_img_config.ctrl_img_3).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_3 = ( + TF.to_tensor(ctrl_img_3) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_3) + + if self.sd.has_multiple_control_images: + ctrl_img = ctrl_img_list + else: + ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None + + positive = self.sd.encode_prompt( gen_img_config.prompt, control_images=ctrl_img @@ -202,6 +249,9 @@ def hook_before_train_loop(self): if self.sd.encode_control_in_text_embeddings: # just do a blank image for unconditionals control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] + kwargs['control_images'] = control_image self.unconditional_embeds = self.sd.encode_prompt( [self.train_config.unconditional_prompt], @@ -272,6 +322,8 @@ def hook_before_train_loop(self): if self.sd.encode_control_in_text_embeddings: # just do a blank image for unconditionals control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype) + if self.sd.has_multiple_control_images: + control_image = [control_image] encode_kwargs['control_images'] = control_image self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs) if self.trigger_word is not None: @@ -305,8 +357,11 @@ def hook_before_train_loop(self): # enable gradient checkpointing on the vae if vae is not None and self.train_config.gradient_checkpointing: - vae.enable_gradient_checkpointing() - vae.train() + try: + vae.enable_gradient_checkpointing() + vae.train() + except: + pass def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch): @@ -574,7 +629,7 @@ def calculate_loss( dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean") additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0 - elif self.dfe.version == 3 or self.dfe.version == 4: + elif self.dfe.version in [3, 4, 5]: dfe_loss = self.dfe( noise=noise, noise_pred=noise_pred, @@ -676,6 +731,10 @@ def calculate_loss( loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") elif self.train_config.loss_type == "wavelet": loss = wavelet_loss(pred, batch.latents, noise) + elif self.train_config.loss_type == "stepped": + loss = stepped_loss(pred, batch.latents, noise, noisy_latents, timesteps, self.sd.noise_scheduler) + # the way this loss works, it is low, increase it to match predictable LR effects + loss = loss * 10.0 else: loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") @@ -1151,6 +1210,11 @@ def train_single_accumulation(self, batch: DataLoaderBatchDTO): has_clip_image_embeds = True # we are caching embeds, handle that differently has_clip_image = False + # do prior pred if prior regularization batch + do_reg_prior = False + if any([batch.file_items[idx].prior_reg for idx in range(len(batch.file_items))]): + do_reg_prior = True + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: raise ValueError( "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. 
Please update your dataset config ") @@ -1645,11 +1709,6 @@ def get_adapter_multiplier(): prior_pred = None - do_reg_prior = False - # if is_reg and (self.network is not None or self.adapter is not None): - # # we are doing a reg image and we have a network or adapter - # do_reg_prior = True - do_inverted_masked_prior = False if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: do_inverted_masked_prior = True diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index df940ceda..96a162e40 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -5,7 +5,9 @@ import concurrent.futures from extensions_built_in.sd_trainer.SDTrainer import SDTrainer from typing import Literal, Optional - +import threading +import time +import signal AITK_Status = Literal["running", "stopped", "error", "completed"] @@ -30,6 +32,49 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): self._async_tasks = [] # Initialize the status self._run_async_operation(self._update_status("running", "Starting")) + self._stop_watcher_started = False + # self.start_stop_watcher(interval_sec=2.0) + + def start_stop_watcher(self, interval_sec: float = 5.0): + """ + Start a daemon thread that periodically checks should_stop() + and terminates the process immediately when triggered. + """ + if getattr(self, "_stop_watcher_started", False): + return + self._stop_watcher_started = True + t = threading.Thread( + target=self._stop_watcher_thread, args=(interval_sec,), daemon=True + ) + t.start() + + def _stop_watcher_thread(self, interval_sec: float): + while True: + try: + if self.should_stop(): + # Mark and update status (non-blocking; uses existing infra) + self.is_stopping = True + self._run_async_operation( + self._update_status("stopped", "Job stopped (remote)") + ) + # Best-effort flush pending async ops + try: + asyncio.run(self.wait_for_all_async()) + except RuntimeError: + pass + # Try to stop DB thread pool quickly + try: + self.thread_pool.shutdown(wait=False, cancel_futures=True) + except TypeError: + self.thread_pool.shutdown(wait=False) + print("") + print("****************************************************") + print(" Stop signal received; terminating process. 
") + print("****************************************************") + os.kill(os.getpid(), signal.SIGINT) + time.sleep(interval_sec) + except Exception: + time.sleep(interval_sec) def _run_async_operation(self, coro): """Helper method to run an async coroutine and track the task.""" diff --git a/extensions_built_in/sd_trainer/__init__.py b/extensions_built_in/sd_trainer/__init__.py index 47c84fa17..065ff8185 100644 --- a/extensions_built_in/sd_trainer/__init__.py +++ b/extensions_built_in/sd_trainer/__init__.py @@ -16,8 +16,10 @@ class SDTrainerExtension(Extension): def get_process(cls): # import your process class here so it is only loaded when needed and return it from .SDTrainer import SDTrainer + return SDTrainer + # This is for generic training (LoRA, Dreambooth, FineTuning) class UITrainerExtension(Extension): # uid must be unique, it is how the extension is identified @@ -32,9 +34,28 @@ class UITrainerExtension(Extension): def get_process(cls): # import your process class here so it is only loaded when needed and return it from .UITrainer import UITrainer + return UITrainer +# This is a universal trainer that can be from ui or api +class DiffusionTrainerExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "diffusion_trainer" + + # name is the name of the extension for printing + name = "Diffusion Trainer" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .DiffusionTrainer import DiffusionTrainer + + return DiffusionTrainer + + # for backwards compatability class TextualInversionTrainer(SDTrainerExtension): uid = "textual_inversion_trainer" @@ -42,5 +63,8 @@ class TextualInversionTrainer(SDTrainerExtension): AI_TOOLKIT_EXTENSIONS = [ # you can put a list of extensions here - SDTrainerExtension, TextualInversionTrainer, UITrainerExtension + SDTrainerExtension, + TextualInversionTrainer, + UITrainerExtension, + DiffusionTrainerExtension, ] diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 4290e966b..115b6aff2 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -173,6 +173,9 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No if raw_datasets is not None and len(raw_datasets) > 0: for raw_dataset in raw_datasets: dataset = DatasetConfig(**raw_dataset) + # handle trigger word per dataset + if dataset.trigger_word is None and self.trigger_word is not None: + dataset.trigger_word = self.trigger_word is_caching = dataset.cache_latents or dataset.cache_latents_to_disk if not is_caching: self.is_latents_cached = False @@ -280,6 +283,7 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No self.current_boundary_index = 0 self.steps_this_boundary = 0 + self.num_consecutive_oom = 0 def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]): # override in subclass @@ -362,6 +366,9 @@ def sample(self, step=None, is_first=False): fps=sample_item.fps, ctrl_img=sample_item.ctrl_img, ctrl_idx=sample_item.ctrl_idx, + ctrl_img_1=sample_item.ctrl_img_1, + ctrl_img_2=sample_item.ctrl_img_2, + ctrl_img_3=sample_item.ctrl_img_3, **extra_args )) @@ -1612,6 +1619,19 @@ def run(self): # if it has it if hasattr(te, 'enable_xformers_memory_efficient_attention'): 
te.enable_xformers_memory_efficient_attention() + + if self.train_config.attention_backend != 'native': + if hasattr(vae, 'set_attention_backend'): + vae.set_attention_backend(self.train_config.attention_backend) + if hasattr(unet, 'set_attention_backend'): + unet.set_attention_backend(self.train_config.attention_backend) + if isinstance(text_encoder, list): + for te in text_encoder: + if hasattr(te, 'set_attention_backend'): + te.set_attention_backend(self.train_config.attention_backend) + else: + if hasattr(text_encoder, 'set_attention_backend'): + text_encoder.set_attention_backend(self.train_config.attention_backend) if self.train_config.sdp: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) @@ -1949,15 +1969,23 @@ def run(self): for group in optimizer.param_groups: previous_lrs.append(group['lr']) - try: - print_acc(f"Loading optimizer state from {optimizer_state_file_path}") - optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) - optimizer.load_state_dict(optimizer_state_dict) - del optimizer_state_dict - flush() - except Exception as e: - print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}") - print_acc(e) + load_optimizer = True + if self.network is not None: + if self.network.did_change_weights: + # do not load optimizer if the network changed, it will result in + # a double state that will oom. + load_optimizer = False + + if load_optimizer: + try: + print_acc(f"Loading optimizer state from {optimizer_state_file_path}") + optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) + optimizer.load_state_dict(optimizer_state_dict) + del optimizer_state_dict + flush() + except Exception as e: + print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}") + print_acc(e) # update the optimizer LR from the params print_acc(f"Updating optimizer LR from params") @@ -2204,17 +2232,31 @@ def run(self): ### HOOK ### if self.torch_profiler is not None: self.torch_profiler.start() - with self.accelerator.accumulate(self.modules_being_trained): - try: + did_oom = False + try: + with self.accelerator.accumulate(self.modules_being_trained): loss_dict = self.hook_train_loop(batch_list) - except Exception as e: - traceback.print_exc() - #print batch info - print("Batch Items:") - for batch in batch_list: - for item in batch.file_items: - print(f" - {item.path}") - raise e + except torch.cuda.OutOfMemoryError: + did_oom = True + except RuntimeError as e: + if "CUDA out of memory" in str(e): + did_oom = True + else: + raise # not an OOM; surface real errors + if did_oom: + self.num_consecutive_oom += 1 + if self.num_consecutive_oom > 3: + raise RuntimeError("OOM during training step 3 times in a row, aborting training") + optimizer.zero_grad(set_to_none=True) + flush() + torch.cuda.ipc_collect() + # skip this step and keep going + print_acc("") + print_acc("################################################") + print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #") + print_acc("################################################") + print_acc("") + self.num_consecutive_oom = 0 if self.torch_profiler is not None: torch.cuda.synchronize() # Make sure all CUDA ops are done self.torch_profiler.stop() @@ -2262,6 +2304,22 @@ def run(self): if self.step_num != self.start_step: if is_sample_step or is_save_step: self.accelerator.wait_for_everyone() + + if is_save_step: + self.accelerator + # print above the progress bar + if self.progress_bar is not None: + 
self.progress_bar.pause() + print_acc(f"\nSaving at step {self.step_num}") + self.save(self.step_num) + self.ensure_params_requires_grad() + # clear any grads + optimizer.zero_grad() + flush() + flush_next = True + if self.progress_bar is not None: + self.progress_bar.unpause() + if is_sample_step: if self.progress_bar is not None: self.progress_bar.pause() diff --git a/requirements.txt b/requirements.txt index ecc672d1b..a0700ab4a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torchao==0.10.0 safetensors git+https://github.com/jaretburkett/easy_dwpose.git -git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63 +git+https://github.com/huggingface/diffusers@1448b035859dd57bbb565239dcdd79a025a85422 transformers==4.52.4 lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 1b295cbd7..bf97a815f 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -4,6 +4,7 @@ import random import torch +import torchaudio from toolkit.prompt_utils import PromptEmbeds @@ -65,7 +66,19 @@ def __init__( self.num_frames: int = kwargs.get('num_frames', sample_config.num_frames) self.ctrl_img: Optional[str] = kwargs.get('ctrl_img', None) self.ctrl_idx: int = kwargs.get('ctrl_idx', 0) + # for multi control image models + self.ctrl_img_1: Optional[str] = kwargs.get('ctrl_img_1', self.ctrl_img) + self.ctrl_img_2: Optional[str] = kwargs.get('ctrl_img_2', None) + self.ctrl_img_3: Optional[str] = kwargs.get('ctrl_img_3', None) + self.network_multiplier: float = kwargs.get('network_multiplier', sample_config.network_multiplier) + # convert to a number if it is a string + if isinstance(self.network_multiplier, str): + try: + self.network_multiplier = float(self.network_multiplier) + except: + print(f"Invalid network_multiplier {self.network_multiplier}, defaulting to 1.0") + self.network_multiplier = 1.0 class SampleConfig: @@ -166,6 +179,9 @@ def __init__(self, **kwargs): elif linear is not None: self.rank: int = linear self.linear: int = linear + else: + self.rank: int = 4 + self.linear: int = 4 self.conv: int = kwargs.get('conv', None) self.alpha: float = kwargs.get('alpha', 1.0) self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) @@ -351,6 +367,8 @@ def __init__(self, **kwargs): self.dtype: str = kwargs.get('dtype', 'fp32') self.xformers = kwargs.get('xformers', False) self.sdp = kwargs.get('sdp', False) + # see https://huggingface.co/docs/diffusers/main/optimization/attention_backends#available-backends for options + self.attention_backend: str = kwargs.get('attention_backend', 'native') # native, flash, _flash_3_hub, _flash_3, self.train_unet = kwargs.get('train_unet', True) self.train_text_encoder = kwargs.get('train_text_encoder', False) self.train_refiner = kwargs.get('train_refiner', True) @@ -812,6 +830,7 @@ def __init__(self, **kwargs): self.buckets: bool = kwargs.get('buckets', True) self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64) self.is_reg: bool = kwargs.get('is_reg', False) + self.prior_reg: bool = kwargs.get('prior_reg', False) self.network_weight: float = float(kwargs.get('network_weight', 1.0)) self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0)) self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False) @@ -823,11 +842,29 @@ def __init__(self, **kwargs): self.control_path: Union[str,List[str]] = kwargs.get('control_path', None) # depth maps, etc if self.control_path == '': 
self.control_path = None + + # handle multi control inputs from the ui. It is just easier to handle it here for a cleaner ui experience + control_path_1 = kwargs.get('control_path_1', None) + control_path_2 = kwargs.get('control_path_2', None) + control_path_3 = kwargs.get('control_path_3', None) + + if any([control_path_1, control_path_2, control_path_3]): + control_paths = [] + if control_path_1: + control_paths.append(control_path_1) + if control_path_2: + control_paths.append(control_path_2) + if control_path_3: + control_paths.append(control_path_3) + self.control_path = control_paths + + # color for transparent reigon of control images with transparency + self.control_transparent_color: List[int] = kwargs.get('control_transparent_color', [0, 0, 0]) # inpaint images should be webp/png images with alpha channel. The alpha 0 (invisible) section will # be the part conditioned to be inpainted. The alpha 1 (visible) section will be the part that is ignored self.inpaint_path: Union[str,List[str]] = kwargs.get('inpaint_path', None) # instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters) - self.full_size_control_images: bool = kwargs.get('full_size_control_images', False) + self.full_size_control_images: bool = kwargs.get('full_size_control_images', True) self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black) @@ -961,6 +998,9 @@ def __init__( extra_values: List[float] = None, # extra values to save with prompt file logger: Optional[EmptyLogger] = None, ctrl_img: Optional[str] = None, # control image for controlnet + ctrl_img_1: Optional[str] = None, # first control image for multi control model + ctrl_img_2: Optional[str] = None, # second control image for multi control model + ctrl_img_3: Optional[str] = None, # third control image for multi control model num_frames: int = 1, fps: int = 15, ctrl_idx: int = 0 @@ -997,6 +1037,12 @@ def __init__( self.ctrl_img = ctrl_img self.ctrl_idx = ctrl_idx + if ctrl_img_1 is None and ctrl_img is not None: + ctrl_img_1 = ctrl_img + + self.ctrl_img_1 = ctrl_img_1 + self.ctrl_img_2 = ctrl_img_2 + self.ctrl_img_3 = ctrl_img_3 # prompt string will override any settings above self._process_prompt_string() @@ -1083,6 +1129,15 @@ def save_image(self, image, count: int = 0, max_count=0): ) else: raise ValueError(f"Unsupported video format {self.output_ext}") + elif self.output_ext in ['wav', 'mp3']: + # save audio file + torchaudio.save( + self.get_image_path(count, max_count), + image[0].to('cpu'), + sample_rate=48000, + format=None, + backend=None + ) else: # TODO save image gen header info for A1111 and us, our seeds probably wont match image.save(self.get_image_path(count, max_count)) diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index b68636639..ca27dd47e 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -121,6 +121,7 @@ def __init__(self, *args, **kwargs): self.network_weight: float = self.dataset_config.network_weight self.is_reg = self.dataset_config.is_reg + self.prior_reg = self.dataset_config.prior_reg self.tensor: Union[torch.Tensor, None] = None def cleanup(self): @@ -143,6 +144,7 @@ def __init__(self, **kwargs): self.tensor: Union[torch.Tensor, None] = None self.latents: Union[torch.Tensor, None] = None 
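For context on the control_tensor_list plumbing introduced just below: the UI-facing control_path_1/2/3 keys handled in the DatasetConfig changes earlier in this diff are collapsed into a single control_path list before the dataloader ever sees them. A small stand-alone sketch of that merge, assuming only the behavior shown in this diff (the helper name is illustrative):

```python
def merge_control_paths(control_path=None, control_path_1=None,
                        control_path_2=None, control_path_3=None):
    # if any numbered UI field is set, the numbered fields win and become a list
    numbered = [p for p in (control_path_1, control_path_2, control_path_3) if p]
    if numbered:
        return numbered
    # otherwise fall back to control_path, which may be None, a folder, or a list
    return control_path

assert merge_control_paths(control_path_1="ctrl/a", control_path_3="ctrl/c") == ["ctrl/a", "ctrl/c"]
assert merge_control_paths(control_path="ctrl/depth") == "ctrl/depth"
```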
self.control_tensor: Union[torch.Tensor, None] = None + self.control_tensor_list: Union[List[List[torch.Tensor]], None] = None self.clip_image_tensor: Union[torch.Tensor, None] = None self.mask_tensor: Union[torch.Tensor, None] = None self.unaugmented_tensor: Union[torch.Tensor, None] = None @@ -159,7 +161,6 @@ def __init__(self, **kwargs): self.latents: Union[torch.Tensor, None] = None if is_latents_cached: self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) - self.control_tensor: Union[torch.Tensor, None] = None self.prompt_embeds: Union[PromptEmbeds, None] = None # if self.file_items[0].control_tensor is not None: # if any have a control tensor, we concatenate them @@ -177,6 +178,16 @@ def __init__(self, **kwargs): else: control_tensors.append(x.control_tensor) self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors]) + + # handle control tensor list + if any([x.control_tensor_list is not None for x in self.file_items]): + self.control_tensor_list = [] + for x in self.file_items: + if x.control_tensor_list is not None: + self.control_tensor_list.append(x.control_tensor_list) + else: + raise Exception(f"Could not find control tensors for all file items, missing for {x.path}") + self.inpaint_tensor: Union[torch.Tensor, None] = None if any([x.inpaint_tensor is not None for x in self.file_items]): diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 69df9cbd7..3fe105924 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -301,6 +301,7 @@ def __init__(self: 'FileItemDTO', *args, **kwargs): dataset_config: DatasetConfig = kwargs.get('dataset_config', None) self.extra_values: List[float] = dataset_config.extra_values + self.trigger_word = dataset_config.trigger_word # todo allow for loading from sd-scripts style dict def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None): @@ -343,6 +344,9 @@ def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None): prompt = clean_caption(prompt) if short_caption is not None: short_caption = clean_caption(short_caption) + + if prompt.strip() == '' and self.dataset_config.default_caption is not None: + prompt = self.dataset_config.default_caption else: prompt = '' if self.dataset_config.default_caption is not None: @@ -364,6 +368,13 @@ def get_caption( add_if_not_present=False, short_caption=False ): + if trigger is None and self.trigger_word is not None: + trigger = self.trigger_word + + if trigger is not None and not self.is_reg: + # add if not present if not regularization + add_if_not_present = True + if short_caption: raw_caption = self.raw_caption_short else: @@ -371,7 +382,7 @@ def get_caption( if raw_caption is None: raw_caption = '' # handle dropout - if self.dataset_config.caption_dropout_rate > 0 and not short_caption: + if self.dataset_config.caption_dropout_rate > 0 and not short_caption and not self.dataset_config.cache_text_embeddings: # get a random float form 0 to 1 rand = random.random() if rand < self.dataset_config.caption_dropout_rate: @@ -386,7 +397,7 @@ def get_caption( token_list = [x for x in token_list if x] # handle token dropout - if self.dataset_config.token_dropout_rate > 0 and not short_caption: + if self.dataset_config.token_dropout_rate > 0 and not short_caption and not self.dataset_config.cache_text_embeddings: new_token_list = [] keep_tokens: int = self.dataset_config.keep_tokens for idx, token in enumerate(token_list): @@ -408,7 +419,7 @@ def get_caption( # join back together caption = ', 
'.join(token_list) - # caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present) + caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present) if self.dataset_config.random_triggers: num_triggers = self.dataset_config.random_triggers_max @@ -433,7 +444,8 @@ def get_caption( token_list = [x for x in token_list if x] random.shuffle(token_list) caption = ', '.join(token_list) - + if caption == '': + pass return caption @@ -842,6 +854,9 @@ def __init__(self: 'FileItemDTO', *args, **kwargs): self.has_control_image = False self.control_path: Union[str, List[str], None] = None self.control_tensor: Union[torch.Tensor, None] = None + self.control_tensor_list: Union[List[torch.Tensor], None] = None + sd = kwargs.get('sd', None) + self.use_raw_control_images = sd is not None and sd.use_raw_control_images dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) self.full_size_control_images = False if dataset_config.control_path is not None: @@ -876,28 +891,30 @@ def load_control_image(self: 'FileItemDTO'): for control_path in control_path_list: try: - img = Image.open(control_path).convert('RGB') + img = Image.open(control_path) img = exif_transpose(img) + + if img.mode in ("RGBA", "LA"): + # Create a background with the specified transparent color + transparent_color = tuple(self.dataset_config.control_transparent_color) + background = Image.new("RGB", img.size, transparent_color) + # Paste the image on top using its alpha channel as mask + background.paste(img, mask=img.getchannel("A")) + img = background + else: + # Already no alpha channel + img = img.convert("RGB") except Exception as e: print_acc(f"Error: {e}") print_acc(f"Error loading image: {control_path}") - + if not self.full_size_control_images: # we just scale them to 512x512: w, h = img.size img = img.resize((512, 512), Image.BICUBIC) - else: + elif not self.use_raw_control_images: w, h = img.size - if w > h and self.scale_to_width < self.scale_to_height: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") - elif h > w and self.scale_to_height < self.scale_to_width: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") - if self.flip_x: # do a flip img = img.transpose(Image.FLIP_LEFT_RIGHT) @@ -931,11 +948,15 @@ def load_control_image(self: 'FileItemDTO'): self.control_tensor = None elif len(control_tensors) == 1: self.control_tensor = control_tensors[0] + elif self.use_raw_control_images: + # just send the list of tensors as their shapes wont match + self.control_tensor_list = control_tensors else: self.control_tensor = torch.stack(control_tensors, dim=0) def cleanup_control(self: 'FileItemDTO'): self.control_tensor = None + self.control_tensor_list = None class ClipImageFileItemDTOMixin: @@ -1795,9 +1816,6 @@ def get_text_embedding_info_dict(self: 'FileItemDTO'): # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible. 
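The control-image loader above stops calling `convert('RGB')` directly on images that carry an alpha channel, because that just drops the alpha and keeps whatever RGB values sit under the transparent pixels. Instead it composites onto a solid background taken from the dataset's `control_transparent_color`. A minimal standalone sketch of that compositing step (the helper name and default color here are illustrative, not part of the toolkit):

```python
from PIL import Image

def flatten_alpha(img: Image.Image, transparent_color=(255, 255, 255)) -> Image.Image:
    """Composite an RGBA/LA image onto a solid background color."""
    if img.mode in ("RGBA", "LA"):
        background = Image.new("RGB", img.size, tuple(transparent_color))
        # the alpha channel is the paste mask, so fully transparent pixels
        # end up as the background color rather than leftover RGB values
        background.paste(img, mask=img.getchannel("A"))
        return background
    return img.convert("RGB")  # no alpha channel: a plain mode conversion is enough
```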
if self.caption is None: self.load_caption() - # throw error is [trigger] in caption as we cannot inject it while caching - if '[trigger]' in self.caption: - raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.") item = OrderedDict([ ("caption", self.caption), ("text_embedding_space_version", self.text_embedding_space_version), @@ -1868,14 +1886,31 @@ def cache_text_embeddings(self: 'AiToolkitDataset'): if file_item.encode_control_in_text_embeddings: if file_item.control_path is None: raise Exception(f"Could not find a control image for {file_item.path} which is needed for this model") - # load the control image and feed it into the text encoder - ctrl_img = Image.open(file_item.control_path).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.sd.device_torch, dtype=self.sd.torch_dtype) - ) + ctrl_img_list = [] + control_path_list = file_item.control_path + if not isinstance(file_item.control_path, list): + control_path_list = [control_path_list] + for i in range(len(control_path_list)): + try: + img = Image.open(control_path_list[i]).convert("RGB") + img = exif_transpose(img) + # convert to 0 to 1 tensor + img = ( + TF.to_tensor(img) + .unsqueeze(0) + .to(self.sd.device_torch, dtype=self.sd.torch_dtype) + ) + ctrl_img_list.append(img) + except Exception as e: + print_acc(f"Error: {e}") + print_acc(f"Error loading control image: {control_path_list[i]}") + + if len(ctrl_img_list) == 0: + ctrl_img = None + elif not self.sd.has_multiple_control_images: + ctrl_img = ctrl_img_list[0] + else: + ctrl_img = ctrl_img_list prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption, control_images=ctrl_img) else: prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index cd4437356..2c4c11a83 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -38,6 +38,10 @@ 'QConv2d', ] +class IdentityModule(torch.nn.Module): + def forward(self, x): + return x + class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. 
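The surrounding `lora_special.py` hunks add a `fullrank` network type: `lora_down` is allocated at the layer's full in-to-out size and `lora_up` is swapped for the new `IdentityModule`, so the adapter learns one dense delta matrix instead of a low-rank product. A rough sketch of the equivalent idea for a linear layer (illustrative only; the toolkit keeps its own initialisation, scaling, and merge logic):

```python
import torch
import torch.nn as nn

class FullRankDelta(nn.Module):
    """Additive dense delta on top of a frozen base linear layer.

    Equivalent to a LoRA whose down projection is full size and whose
    up projection is the identity: y = W x + s * W_d x.
    """
    def __init__(self, base: nn.Linear, scale: float = 1.0):
        super().__init__()
        self.base = base.requires_grad_(False)  # original weights stay frozen
        self.delta = nn.Linear(base.in_features, base.out_features, bias=False)
        nn.init.zeros_(self.delta.weight)       # zero init here so the wrap starts as a no-op
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.delta(x)

    @torch.no_grad()
    def merge(self) -> None:
        # merging collapses to a plain weight addition, mirroring the
        # full_rank branch added to merge_in() further down in this diff
        self.base.weight += self.scale * self.delta.weight
        self.delta.weight.zero_()
```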
@@ -81,16 +85,25 @@ def __init__( # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim + self.full_rank = network.network_type.lower() == "fullrank" if org_module.__class__.__name__ in CONV_MODULES: kernel_size = org_module.kernel_size stride = org_module.stride padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias) + if self.full_rank: + self.lora_down = torch.nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False) + self.lora_up = IdentityModule() + else: + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias) else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias) + if self.full_rank: + self.lora_down = torch.nn.Linear(in_dim, out_dim, bias=False) + self.lora_up = IdentityModule() + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error @@ -100,7 +113,8 @@ def __init__( # same as microsoft's torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) + if not self.full_rank: + torch.nn.init.zeros_(self.lora_up.weight) self.multiplier: Union[float, List[float]] = multiplier # wrap the original module so it doesn't get weights updated @@ -232,6 +246,7 @@ def __init__( self.is_lumina2 = is_lumina2 self.network_type = network_type self.is_assistant_adapter = is_assistant_adapter + self.full_rank = network_type.lower() == "fullrank" if self.network_type.lower() == "dora": self.module_class = DoRAModule module_class = DoRAModule @@ -426,7 +441,10 @@ def create_modules( except: pass else: - lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)] + if self.full_rank: + lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)] + else: + lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)] return loras, skipped text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] diff --git a/toolkit/models/FakeVAE.py b/toolkit/models/FakeVAE.py new file mode 100644 index 000000000..90e3d507c --- /dev/null +++ b/toolkit/models/FakeVAE.py @@ -0,0 +1,135 @@ +from diffusers import AutoencoderKL +from typing import Optional, Union +import torch +import torch.nn as nn +import numpy as np +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput +from diffusers.models.autoencoders.vae import DecoderOutput + + +class Config: + in_channels = 3 + out_channels = 3 + down_block_types = ("1",) + up_block_types = ("1",) + block_out_channels = (1,) + latent_channels = 3 # usually 4 + norm_num_groups = 1 + sample_size = 512 + scaling_factor = 1.0 + # scaling_factor = 1.8 + shift_factor = 0 + + def __getitem__(cls, x): + return getattr(cls, x) + + +class FakeVAE(nn.Module): + def __init__(self, scaling_factor=1.0): + super().__init__() + self._dtype = torch.float32 + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.config = Config() + self.config.scaling_factor = 
scaling_factor + + @property + def dtype(self): + return self._dtype + + @dtype.setter + def dtype(self, value): + self._dtype = value + + @property + def device(self): + return self._device + + @device.setter + def device(self, value): + self._device = value + + # mimic to from torch + def to(self, *args, **kwargs): + # pull out dtype and device if they exist + if "dtype" in kwargs: + self._dtype = kwargs["dtype"] + if "device" in kwargs: + self._device = kwargs["device"] + return super().to(*args, **kwargs) + + def enable_xformers_memory_efficient_attention(self): + pass + + # @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> AutoencoderKLOutput: + h = x + + # moments = self.quant_conv(h) + # posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (h,) + + class FakeDist: + def __init__(self, x): + self._sample = x + + def sample(self): + return self._sample + + return AutoencoderKLOutput(latent_dist=FakeDist(h)) + + def _decode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + dec = z + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def _set_gradient_checkpointing(self, module, value=False): + pass + + def enable_tiling(self, use_tiling: bool = True): + pass + + def disable_tiling(self): + pass + + def enable_slicing(self): + pass + + def disable_slicing(self): + pass + + def set_use_memory_efficient_attention_xformers(self, value: bool = True): + pass + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + dec = sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index af5ea06fe..58d48b4fe 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -181,6 +181,10 @@ def __init__( # set true for models that encode control image into text embeddings self.encode_control_in_text_embeddings = False + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = False + # do not resize control images + self.use_raw_control_images = False # properties for old arch for backwards compatibility @property diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 0f7bff09a..65c2de3cd 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -140,13 +140,13 @@ def forward(self, x): class DiffusionFeatureExtractor(nn.Module): - def __init__(self, in_channels=16): + def __init__(self, in_channels=16, out_channels=512): super().__init__() self.version = 1 num_blocks = 6 self.conv_in = nn.Conv2d(in_channels, 512, 1) self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)]) - self.conv_out = nn.Conv2d(512, 512, 1) + self.conv_out = nn.Conv2d(512, out_channels, 1) def forward(self, x): x = self.conv_in(x) @@ -470,10 +470,31 @@ def get_siglip_features(self, tensors_0_1): output_hidden_states=True, ) - # embeds = 
id_embeds['hidden_states'][-2] # penultimate layer - image_embeds = id_embeds['pooler_output'] - image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + image_embeds = id_embeds['hidden_states'][-2] # penultimate layer + # image_embeds = id_embeds['pooler_output'] + # image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) return image_embeds + + def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler): + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + scheduler._step_index = None + scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = scheduler.sigmas[scheduler.step_index] + sigma_next = scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) + return stepped_latents def forward( self, @@ -509,31 +530,12 @@ def forward( if model is not None and hasattr(model, 'get_stepped_pred'): stepped_latents = model.get_stepped_pred(noise_pred, noise) else: - # stepped_latents = noise - noise_pred - # first we step the scheduler from current timestep to the very end for a full denoise - bs = noise_pred.shape[0] - noise_pred_chunks = torch.chunk(noise_pred, bs) - timestep_chunks = torch.chunk(timesteps, bs) - noisy_latent_chunks = torch.chunk(noisy_latents, bs) - stepped_chunks = [] - for idx in range(bs): - model_output = noise_pred_chunks[idx] - timestep = timestep_chunks[idx] - scheduler._step_index = None - scheduler._init_step_index(timestep) - sample = noisy_latent_chunks[idx].to(torch.float32) - - sigma = scheduler.sigmas[scheduler.step_index] - sigma_next = scheduler.sigmas[-1] # use last sigma for final step - prev_sample = sample + (sigma_next - sigma) * model_output - stepped_chunks.append(prev_sample) - - stepped_latents = torch.cat(stepped_chunks, dim=0) + stepped_latents = self.step_latents(noise, noise_pred, noisy_latents, timesteps, scheduler) latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype) - - scaling_factor = self.vae.config['scaling_factor'] if 'scaling_factor' in self.vae.config else 1.0 - shift_factor = self.vae.config['shift_factor'] if 'shift_factor' in self.vae.config else 0.0 + + scaling_factor = self.vae.config.scaling_factor if hasattr(self.vae.config, 'scaling_factor') else 1.0 + shift_factor = self.vae.config.shift_factor if hasattr(self.vae.config, 'shift_factor') else 0.0 latents = (latents / scaling_factor) + shift_factor if is_video: # if video, we need to unsqueeze the latents to match the vae input shape @@ -590,6 +592,52 @@ def forward( self.step += 1 return total_loss + +class DiffusionFeatureExtractor5(DiffusionFeatureExtractor4): + def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None): + super().__init__(device=device, dtype=dtype, vae=vae) + self.version = 5 + + def step_latents(self, noise, noise_pred, noisy_latents, timesteps, scheduler, total_steps: int = 1000, eps: float = 1e-6): + bs = noise_pred.shape[0] + + # Chunk inputs per-sample (keeps existing structure) + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) 
+ noise_chunks = torch.chunk(noise, bs) + + stepped_chunks = [] + x0_pred_chunks = [] + + for idx in range(bs): + model_output = noise_pred_chunks[idx] # predicted noise (same shape as latent) + timestep = timestep_chunks[idx] # scalar tensor per sample (e.g., [t]) + sample = noisy_latent_chunks[idx].to(torch.float32) + noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device) + + # Initialize scheduler step index for this sample + scheduler._step_index = None + scheduler._init_step_index(timestep) + + # ---- Step +50 indices (or to the end) in sigma-space ---- + sigma = scheduler.sigmas[scheduler.step_index] + target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1) + sigma_next = scheduler.sigmas[target_idx] + + # One-step update along the model-predicted direction + stepped = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(stepped) + + # ---- Inverse-Gaussian recovery at the target timestep ---- + t_01 = (scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype) + original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01) + x0_pred_chunks.append(original_samples) + + # stepped_latents = torch.cat(stepped_chunks, dim=0) + predicted_images = torch.cat(x0_pred_chunks, dim=0) + # return stepped_latents, predicted_images + return predicted_images def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor: if model_path == "v3": @@ -600,6 +648,10 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor: dfe = DiffusionFeatureExtractor4(vae=vae) dfe.eval() return dfe + if model_path == "v5": + dfe = DiffusionFeatureExtractor5(vae=vae) + dfe.eval() + return dfe if not os.path.exists(model_path): raise FileNotFoundError(f"Model file not found: {model_path}") # if it ende with safetensors @@ -611,7 +663,9 @@ def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor: state_dict = state_dict['model_state_dict'] if 'conv_in.weight' in state_dict: - dfe = DiffusionFeatureExtractor() + # determine num out channels + out_channels = state_dict['conv_out.weight'].shape[0] + dfe = DiffusionFeatureExtractor(out_channels=out_channels) else: dfe = DiffusionFeatureExtractor2() diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 59c15d3d2..617f4baa9 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -46,6 +46,15 @@ 'percentage' ] +printed_messages = [] + + +def print_once(msg): + global printed_messages + if msg not in printed_messages: + print(msg) + printed_messages.append(msg) + def broadcast_and_multiply(tensor, multiplier): # Determine the number of dimensions required @@ -340,7 +349,10 @@ def merge_in(self: Module, merge_weight=1.0): if not self.can_merge_in: return # get up/down weight - up_weight = self.lora_up.weight.clone().float() + if self.full_rank: + up_weight = None + else: + up_weight = self.lora_up.weight.clone().float() down_weight = self.lora_down.weight.clone().float() # extract weight from org_module @@ -365,7 +377,9 @@ def merge_in(self: Module, merge_weight=1.0): scale = scale * self.scalar # merge weight - if len(weight.size()) == 2: + if self.full_rank: + weight = weight + multiplier * down_weight * scale + elif len(weight.size()) == 2: # linear weight = weight + multiplier * (up_weight @ down_weight) * scale elif down_weight.size()[2:4] == (1, 1): @@ -434,6 +448,8 @@ def __init__( self.module_losses: List[torch.Tensor] = [] self.lorm_train_mode: Literal['local', None] = None self.can_merge_in = not is_lorm + # will prevent optimizer from loading as it will have double 
states + self.did_change_weights = False def get_keymap(self: Network, force_weight_mapping=False): use_weight_mapping = False @@ -634,6 +650,46 @@ def load_weights(self: Network, file, force_weight_mapping=False): if key not in current_state_dict: extra_dict[key] = load_sd[key] to_delete.append(key) + elif "lora_down" in key or "lora_up" in key: + # handle expanding/shrinking LoRA (linear only) + if len(load_sd[key].shape) == 2: + load_value = load_sd[key] # from checkpoint + blank_val = current_state_dict[key] # shape we need in the target model + tgt_h, tgt_w = blank_val.shape + src_h, src_w = load_value.shape + + if (src_h, src_w) == (tgt_h, tgt_w): + # shapes already match: keep original + pass + + elif "lora_down" in key and src_h < tgt_h: + print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}") + new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype) + new_val[:src_h, :src_w] = load_value # src_w should already match + load_sd[key] = new_val + self.did_change_weights = True + + elif "lora_up" in key and src_w < tgt_w: + print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}") + new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype) + new_val[:src_h, :src_w] = load_value # src_h should already match + load_sd[key] = new_val + self.did_change_weights = True + + elif "lora_down" in key and src_h > tgt_h: + print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}") + load_sd[key] = load_value[:tgt_h, :tgt_w] + self.did_change_weights = True + + elif "lora_up" in key and src_w > tgt_w: + print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}") + load_sd[key] = load_value[:tgt_h, :tgt_w] + self.did_change_weights = True + + else: + # unexpected mismatch (e.g., both dims differ in a way that doesn't match lora_up/down semantics) + raise ValueError(f"Unhandled LoRA shape change for {key}: src={load_value.shape}, tgt={blank_val.shape}") + for key in to_delete: del load_sd[key] diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index b0558b8cb..b8a6f1e56 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -21,7 +21,7 @@ class ACTION_TYPES_SLIDER: class PromptEmbeds: # text_embeds: torch.Tensor # pooled_embeds: Union[torch.Tensor, None] - # attention_mask: Union[torch.Tensor, None] + # attention_mask: Union[torch.Tensor, List[torch.Tensor], None] def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None: if isinstance(args, list) or isinstance(args, tuple): @@ -43,7 +43,10 @@ def to(self, *args, **kwargs): if self.pooled_embeds is not None: self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs) if self.attention_mask is not None: - self.attention_mask = self.attention_mask.to(*args, **kwargs) + if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple): + self.attention_mask = [t.to(*args, **kwargs) for t in self.attention_mask] + else: + self.attention_mask = self.attention_mask.to(*args, **kwargs) return self def detach(self): @@ -55,7 +58,10 @@ def detach(self): if new_embeds.pooled_embeds is not None: new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach() if new_embeds.attention_mask is not None: - new_embeds.attention_mask = new_embeds.attention_mask.detach() + if isinstance(new_embeds.attention_mask, list) or isinstance(new_embeds.attention_mask, tuple): + new_embeds.attention_mask = [t.detach() for t in 
new_embeds.attention_mask] + else: + new_embeds.attention_mask = new_embeds.attention_mask.detach() return new_embeds def clone(self): @@ -69,7 +75,10 @@ def clone(self): prompt_embeds = PromptEmbeds(cloned_text_embeds) if self.attention_mask is not None: - prompt_embeds.attention_mask = self.attention_mask.clone() + if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple): + prompt_embeds.attention_mask = [t.clone() for t in self.attention_mask] + else: + prompt_embeds.attention_mask = self.attention_mask.clone() return prompt_embeds def expand_to_batch(self, batch_size): @@ -89,7 +98,10 @@ def expand_to_batch(self, batch_size): if pe.pooled_embeds is not None: pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1) if pe.attention_mask is not None: - pe.attention_mask = pe.attention_mask.expand(batch_size, -1) + if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple): + pe.attention_mask = [t.expand(batch_size, -1) for t in pe.attention_mask] + else: + pe.attention_mask = pe.attention_mask.expand(batch_size, -1) return pe def save(self, path: str): @@ -108,7 +120,11 @@ def save(self, path: str): if pe.pooled_embeds is not None: state_dict["pooled_embed"] = pe.pooled_embeds.cpu() if pe.attention_mask is not None: - state_dict["attention_mask"] = pe.attention_mask.cpu() + if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple): + for i, attn in enumerate(pe.attention_mask): + state_dict[f"attention_mask_{i}"] = attn.cpu() + else: + state_dict["attention_mask"] = pe.attention_mask.cpu() os.makedirs(os.path.dirname(path), exist_ok=True) save_file(state_dict, path) @@ -122,7 +138,7 @@ def load(cls, path: str) -> 'PromptEmbeds': state_dict = load_file(path, device='cpu') text_embeds = [] pooled_embeds = None - attention_mask = None + attention_mask = [] for key in sorted(state_dict.keys()): if key.startswith("text_embed_"): text_embeds.append(state_dict[key]) @@ -130,19 +146,25 @@ def load(cls, path: str) -> 'PromptEmbeds': text_embeds.append(state_dict[key]) elif key == "pooled_embed": pooled_embeds = state_dict[key] + elif key.startswith("attention_mask_"): + attention_mask.append(state_dict[key]) elif key == "attention_mask": - attention_mask = state_dict[key] + attention_mask.append(state_dict[key]) pe = cls(None) pe.text_embeds = text_embeds if len(text_embeds) == 1: pe.text_embeds = text_embeds[0] if pooled_embeds is not None: pe.pooled_embeds = pooled_embeds - if attention_mask is not None: - pe.attention_mask = attention_mask + if len(attention_mask) > 0: + if len(attention_mask) == 1: + pe.attention_mask = attention_mask[0] + else: + pe.attention_mask = attention_mask return pe + class EncodedPromptPair: def __init__( self, @@ -210,20 +232,63 @@ def detach(self): return self -def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]): - if isinstance(prompt_embeds[0].text_embeds, list) or isinstance(prompt_embeds[0].text_embeds, tuple): +def concat_prompt_embeds(prompt_embeds: list["PromptEmbeds"]): + # --- pad text_embeds --- + if isinstance(prompt_embeds[0].text_embeds, (list, tuple)): embed_list = [] for i in range(len(prompt_embeds[0].text_embeds)): - embed_list.append(torch.cat([p.text_embeds[i] for p in prompt_embeds], dim=0)) + max_len = max(p.text_embeds[i].shape[1] for p in prompt_embeds) + padded = [] + for p in prompt_embeds: + t = p.text_embeds[i] + if t.shape[1] < max_len: + pad = torch.zeros( + (t.shape[0], max_len - t.shape[1], *t.shape[2:]), + dtype=t.dtype, + device=t.device, + ) + t = 
torch.cat([t, pad], dim=1) + padded.append(t) + embed_list.append(torch.cat(padded, dim=0)) text_embeds = embed_list else: - text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0) + max_len = max(p.text_embeds.shape[1] for p in prompt_embeds) + padded = [] + for p in prompt_embeds: + t = p.text_embeds + if t.shape[1] < max_len: + pad = torch.zeros( + (t.shape[0], max_len - t.shape[1], *t.shape[2:]), + dtype=t.dtype, + device=t.device, + ) + t = torch.cat([t, pad], dim=1) + padded.append(t) + text_embeds = torch.cat(padded, dim=0) + + # --- pooled embeds --- pooled_embeds = None if prompt_embeds[0].pooled_embeds is not None: pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0) + + # --- attention mask --- attention_mask = None if prompt_embeds[0].attention_mask is not None: - attention_mask = torch.cat([p.attention_mask for p in prompt_embeds], dim=0) + max_len = max(p.attention_mask.shape[1] for p in prompt_embeds) + padded = [] + for p in prompt_embeds: + m = p.attention_mask + if m.shape[1] < max_len: + pad = torch.zeros( + (m.shape[0], max_len - m.shape[1]), + dtype=m.dtype, + device=m.device, + ) + m = torch.cat([m, pad], dim=1) + padded.append(m) + attention_mask = torch.cat(padded, dim=0) + + # wrap back into PromptEmbeds pe = PromptEmbeds([text_embeds, pooled_embeds]) pe.attention_mask = attention_mask return pe diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index b042d6998..78960ed1b 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -219,6 +219,10 @@ def __init__( # set true for models that encode control image into text embeddings self.encode_control_in_text_embeddings = False + # control images will come in as a list for encoding some things if true + self.has_multiple_control_images = False + # do not resize control images + self.use_raw_control_images = False # properties for old arch for backwards compatibility @property @@ -1145,6 +1149,8 @@ def generate_images( # the network to drastically speed up inference unique_network_weights = set([x.network_multiplier for x in image_configs]) if len(unique_network_weights) == 1 and network.can_merge_in: + # make sure it is on device before merging. 
+ self.unet.to(self.device_torch) can_merge_in = True merge_multiplier = unique_network_weights.pop() network.merge_in(merge_weight=merge_multiplier) diff --git a/toolkit/util/losses.py b/toolkit/util/losses.py new file mode 100644 index 000000000..5328674ad --- /dev/null +++ b/toolkit/util/losses.py @@ -0,0 +1,93 @@ +import torch + + +_dwt = None + + +def _get_wavelet_loss(device, dtype): + global _dwt + if _dwt is not None: + return _dwt + + # init wavelets + from pytorch_wavelets import DWTForward + + # wave='db1' wave='haar' + dwt = DWTForward(J=1, mode="zero", wave="haar").to(device=device, dtype=dtype) + _dwt = dwt + return dwt + + +def wavelet_loss(model_pred, latents, noise): + model_pred = model_pred.float() + latents = latents.float() + noise = noise.float() + dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype) + with torch.no_grad(): + model_input_xll, model_input_xh = dwt(latents) + model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind( + model_input_xh[0], dim=2 + ) + model_input = torch.cat( + [model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1 + ) + + # reverse the noise to get the model prediction of the pure latents + model_pred = noise - model_pred + + model_pred_xll, model_pred_xh = dwt(model_pred) + model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind( + model_pred_xh[0], dim=2 + ) + model_pred = torch.cat( + [model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1 + ) + + return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none") + + +def stepped_loss(model_pred, latents, noise, noisy_latents, timesteps, scheduler): + # this steps the on a 20 step timescale from the current step (50 idx steps ahead) + # and then reconstructs the original image at that timestep. This should lessen the error + # possible in high noise timesteps and make the flow smoother. 
+ bs = model_pred.shape[0] + + noise_pred_chunks = torch.chunk(model_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + noise_chunks = torch.chunk(noise, bs) + + x0_pred_chunks = [] + + for idx in range(bs): + model_output = noise_pred_chunks[idx] # predicted noise (same shape as latent) + timestep = timestep_chunks[idx] # scalar tensor per sample (e.g., [t]) + sample = noisy_latent_chunks[idx].to(torch.float32) + noise_i = noise_chunks[idx].to(sample.dtype).to(sample.device) + + # Initialize scheduler step index for this sample + scheduler._step_index = None + scheduler._init_step_index(timestep) + + # ---- Step +50 indices (or to the end) in sigma-space ---- + sigma = scheduler.sigmas[scheduler.step_index] + target_idx = min(scheduler.step_index + 50, len(scheduler.sigmas) - 1) + sigma_next = scheduler.sigmas[target_idx] + + # One-step update along the model-predicted direction + stepped = sample + (sigma_next - sigma) * model_output + + # ---- Inverse-Gaussian recovery at the target timestep ---- + t_01 = ( + (scheduler.sigmas[target_idx]).to(stepped.device).to(stepped.dtype) + ) + original_samples = (stepped - t_01 * noise_i) / (1.0 - t_01) + x0_pred_chunks.append(original_samples) + + predicted_images = torch.cat(x0_pred_chunks, dim=0) + + return torch.nn.functional.mse_loss( + predicted_images.float(), + latents.float().to(device=predicted_images.device), + reduction="none", + ) diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py index f421190ca..31a96bd15 100644 --- a/toolkit/util/quantize.py +++ b/toolkit/util/quantize.py @@ -261,6 +261,7 @@ def quantize_model( base_model.accuracy_recovery_adapter = network # quantize it + lora_exclude_modules = [] quantization_type = get_qtype(base_model.model_config.qtype) for lora_module in tqdm(network.unet_loras, desc="Attaching quantization"): # the lora has already hijacked the original module @@ -271,10 +272,20 @@ def quantize_model( param.requires_grad = False quantize(orig_module, weights=quantization_type) freeze(orig_module) + module_name = lora_module.lora_name.replace('$$', '.').replace('transformer.', '') + lora_exclude_modules.append(module_name) if base_model.model_config.low_vram: # move it back to cpu orig_module.to("cpu") - + pass + # quantize additional layers + print_acc(" - quantizing additional layers") + quantization_type = get_qtype('uint8') + quantize( + model_to_quantize, + weights=quantization_type, + exclude=lora_exclude_modules + ) else: # quantize model the original way without an accuracy recovery adapter # move and quantize only certain pieces at a time. 
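`stepped_loss` above and `DiffusionFeatureExtractor5.step_latents` earlier in this diff share the same two lines of flow-matching algebra: an Euler step `x_sigma' = x_sigma + (sigma' - sigma) * v` along the predicted velocity, followed by inverting the interpolation `x_sigma' = (1 - sigma') * x0 + sigma' * noise` to read off an x0 estimate. A self-contained sketch of just that recovery, assuming sigmas in [0, 1] and a velocity-style prediction, as the surrounding code does:

```python
import torch

def recover_x0(noisy_latent: torch.Tensor,
               velocity_pred: torch.Tensor,
               noise: torch.Tensor,
               sigma: float,
               sigma_target: float) -> torch.Tensor:
    """Euler-step from sigma to sigma_target, then invert the flow-matching mix.

    Assumes the rectified-flow convention x_sigma = (1 - sigma) * x0 + sigma * noise,
    under which a perfect velocity prediction equals (noise - x0).
    """
    # one first-order step along the predicted direction, using the same
    # update the surrounding code applies with the scheduler's sigmas
    stepped = noisy_latent + (sigma_target - sigma) * velocity_pred
    # invert x_{sigma_target} = (1 - sigma_target) * x0 + sigma_target * noise
    # (sigma_target < 1, so the denominator stays positive)
    return (stepped - sigma_target * noise) / (1.0 - sigma_target)
```

Stepping roughly 50 scheduler indices toward the data end first means the division by `(1 - sigma_target)` happens at a lower noise level than the sample's own sigma, which is where the comment about reducing error at high-noise timesteps comes from; stepping all the way to the final sigma, as the older `step_latents` path does, amounts to skipping the inversion entirely.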
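The `load_weights` change in `toolkit/network_mixins.py` above zero-pads (or truncates) `lora_down` along its rank dimension (rows) and `lora_up` along its rank dimension (columns) when a checkpoint was saved at a different rank. Zero-padding leaves the composed weight unchanged: the extra rows of `lora_down` only ever multiply the extra zero columns of `lora_up`. A small sketch that checks this, with shapes chosen arbitrarily for illustration:

```python
import torch

def pad_rank(down: torch.Tensor, up: torch.Tensor, new_rank: int):
    """Zero-pad a LoRA pair (down: [r, in], up: [out, r]) to a larger rank."""
    r, in_dim = down.shape
    out_dim = up.shape[0]
    down_p = torch.zeros(new_rank, in_dim, dtype=down.dtype)
    up_p = torch.zeros(out_dim, new_rank, dtype=up.dtype)
    down_p[:r, :] = down
    up_p[:, :r] = up
    return down_p, up_p

down = torch.randn(8, 64)   # rank 8, in_dim 64 (arbitrary example sizes)
up = torch.randn(128, 8)    # out_dim 128
down_p, up_p = pad_rank(down, up, new_rank=32)

# the adapter's effective weight (up @ down) is unchanged by the padding
assert torch.allclose(up @ down, up_p @ down_p, atol=1e-6)
```

Truncating to a smaller rank, which the same hunk also handles, is not value-preserving in the same way; it simply drops the highest-index rank components.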
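`concat_prompt_embeds` in `toolkit/prompt_utils.py` above now right-pads every text embedding (and its attention mask) to the longest sequence in the group before concatenating along the batch dimension, which lets cached embeddings of different token lengths be batched together. A minimal sketch of that padding step, written against plain tensors rather than the `PromptEmbeds` wrapper:

```python
import torch

def pad_and_stack(embeds: list[torch.Tensor], masks: list[torch.Tensor]):
    """Right-pad [1, seq, dim] embeddings and [1, seq] masks to a common length."""
    max_len = max(e.shape[1] for e in embeds)
    padded_e, padded_m = [], []
    for e, m in zip(embeds, masks):
        pad = max_len - e.shape[1]
        if pad > 0:
            # zero-pad the sequence dimension; the matching zeros in the
            # attention mask mark those positions as padding
            e = torch.cat([e, e.new_zeros(e.shape[0], pad, e.shape[2])], dim=1)
            m = torch.cat([m, m.new_zeros(m.shape[0], pad)], dim=1)
        padded_e.append(e)
        padded_m.append(m)
    return torch.cat(padded_e, dim=0), torch.cat(padded_m, dim=0)
```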
diff --git a/toolkit/util/wavelet_loss.py b/toolkit/util/wavelet_loss.py deleted file mode 100644 index 8ae9e5a2d..000000000 --- a/toolkit/util/wavelet_loss.py +++ /dev/null @@ -1,39 +0,0 @@ - -import torch - - -_dwt = None - - -def _get_wavelet_loss(device, dtype): - global _dwt - if _dwt is not None: - return _dwt - - # init wavelets - from pytorch_wavelets import DWTForward - # wave='db1' wave='haar' - dwt = DWTForward(J=1, mode='zero', wave='haar').to( - device=device, dtype=dtype) - _dwt = dwt - return dwt - - -def wavelet_loss(model_pred, latents, noise): - model_pred = model_pred.float() - latents = latents.float() - noise = noise.float() - dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype) - with torch.no_grad(): - model_input_xll, model_input_xh = dwt(latents) - model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(model_input_xh[0], dim=2) - model_input = torch.cat([model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1) - - # reverse the noise to get the model prediction of the pure latents - model_pred = noise - model_pred - - model_pred_xll, model_pred_xh = dwt(model_pred) - model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(model_pred_xh[0], dim=2) - model_pred = torch.cat([model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1) - - return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none") \ No newline at end of file diff --git a/ui/next.config.ts b/ui/next.config.ts index 8655fe00c..b0e860a74 100644 --- a/ui/next.config.ts +++ b/ui/next.config.ts @@ -1,6 +1,9 @@ import type { NextConfig } from 'next'; const nextConfig: NextConfig = { + devIndicators: { + buildActivity: false, + }, typescript: { // Remove this. Build fails because of route types ignoreBuildErrors: true, diff --git a/ui/package-lock.json b/ui/package-lock.json index f4ed097b1..e2b6183e5 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -11,6 +11,7 @@ "@headlessui/react": "^2.2.0", "@monaco-editor/react": "^4.7.0", "@prisma/client": "^6.3.1", + "archiver": "^7.0.1", "axios": "^1.7.9", "classnames": "^2.5.1", "lucide-react": "^0.475.0", @@ -28,6 +29,7 @@ "yaml": "^2.7.0" }, "devDependencies": { + "@types/archiver": "^6.0.3", "@types/node": "^20", "@types/react": "^19", "@types/react-dom": "^19", @@ -725,7 +727,6 @@ "version": "8.0.2", "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==", - "dev": true, "dependencies": { "string-width": "^5.1.2", "string-width-cjs": "npm:string-width@^4.2.0", @@ -992,7 +993,6 @@ "version": "0.11.0", "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz", "integrity": "sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==", - "dev": true, "optional": true, "engines": { "node": ">=14" @@ -1217,6 +1217,15 @@ "integrity": "sha512-vxhUy4J8lyeyinH7Azl1pdd43GJhZH/tP2weN8TntQblOY+A0XbT8DJk1/oCPuOOyg/Ja757rG0CgHcWC8OfMA==", "dev": true }, + "node_modules/@types/archiver": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/@types/archiver/-/archiver-6.0.3.tgz", + "integrity": "sha512-a6wUll6k3zX6qs5KlxIggs1P1JcYJaTCx2gnlr+f0S1yd2DoaEwoIK10HmBaLnZwWneBz+JBm0dwcZu0zECBcQ==", + "dev": true, + "dependencies": { + "@types/readdir-glob": "*" + } + }, "node_modules/@types/node": { "version": "20.17.19", "resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.19.tgz", @@ 
-1256,6 +1265,15 @@ "@types/react": "*" } }, + "node_modules/@types/readdir-glob": { + "version": "1.1.5", + "resolved": "https://registry.npmjs.org/@types/readdir-glob/-/readdir-glob-1.1.5.tgz", + "integrity": "sha512-raiuEPUYqXu+nvtY2Pe8s8FEmZ3x5yAH4VkLdihcPdalvsHltomrRC9BzuStrJ9yk06470hS0Crw0f1pXqD+Hg==", + "dev": true, + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@types/strip-bom": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/@types/strip-bom/-/strip-bom-3.0.0.tgz", @@ -1275,6 +1293,17 @@ "license": "ISC", "optional": true }, + "node_modules/abort-controller": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/abort-controller/-/abort-controller-3.0.0.tgz", + "integrity": "sha512-h8lQ8tacZYnR3vNQTgibj+tODHI5/+l06Au2Pcriv/Gmet0eaj4TwWH41sO9wnHDiQsEj19q0drzdWdeAHtweg==", + "dependencies": { + "event-target-shim": "^5.0.0" + }, + "engines": { + "node": ">=6.5" + } + }, "node_modules/acorn": { "version": "8.15.0", "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", @@ -1343,7 +1372,6 @@ "version": "6.1.0", "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.1.0.tgz", "integrity": "sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA==", - "dev": true, "engines": { "node": ">=12" }, @@ -1355,7 +1383,6 @@ "version": "6.2.1", "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.1.tgz", "integrity": "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==", - "dev": true, "engines": { "node": ">=12" }, @@ -1389,6 +1416,126 @@ "license": "ISC", "optional": true }, + "node_modules/archiver": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/archiver/-/archiver-7.0.1.tgz", + "integrity": "sha512-ZcbTaIqJOfCc03QwD468Unz/5Ir8ATtvAHsK+FdXbDIbGfihqh9mrvdcYunQzqn4HrvWWaFyaxJhGZagaJJpPQ==", + "dependencies": { + "archiver-utils": "^5.0.2", + "async": "^3.2.4", + "buffer-crc32": "^1.0.0", + "readable-stream": "^4.0.0", + "readdir-glob": "^1.1.2", + "tar-stream": "^3.0.0", + "zip-stream": "^6.0.1" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/archiver-utils": { + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/archiver-utils/-/archiver-utils-5.0.2.tgz", + "integrity": "sha512-wuLJMmIBQYCsGZgYLTy5FIB2pF6Lfb6cXMSF8Qywwk3t20zWnAi7zLcQFdKQmIB8wyZpY5ER38x08GbwtR2cLA==", + "dependencies": { + "glob": "^10.0.0", + "graceful-fs": "^4.2.0", + "is-stream": "^2.0.1", + "lazystream": "^1.0.0", + "lodash": "^4.17.15", + "normalize-path": "^3.0.0", + "readable-stream": "^4.0.0" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/archiver-utils/node_modules/buffer": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/archiver-utils/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + 
"dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + }, + "node_modules/archiver/node_modules/buffer": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/archiver/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + "dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + }, + "node_modules/archiver/node_modules/tar-stream": { + "version": "3.1.7", + "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.7.tgz", + "integrity": "sha512-qJj60CXt7IU1Ffyc3NJMjh6EkuCFej46zUqJ4J7pqYlThyd9bO0XBTmcOIhSzZJVWfsLks0+nle/j538YAW9RQ==", + "dependencies": { + "b4a": "^1.6.4", + "fast-fifo": "^1.2.0", + "streamx": "^2.15.0" + } + }, "node_modules/are-we-there-yet": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/are-we-there-yet/-/are-we-there-yet-3.0.1.tgz", @@ -1410,6 +1557,11 @@ "integrity": "sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg==", "dev": true }, + "node_modules/async": { + "version": "3.2.6", + "resolved": "https://registry.npmjs.org/async/-/async-3.2.6.tgz", + "integrity": "sha512-htCUDlxyyCLMgaM3xXg0C0LW2xqfuQ6p05pCEIsXuyQ+a1koYKTuBMzRNwmybfLgvJDMd0r1LTn4+E0Ti6C2AA==" + }, "node_modules/asynckit": { "version": "0.4.0", "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", @@ -1433,6 +1585,11 @@ "proxy-from-env": "^1.1.0" } }, + "node_modules/b4a": { + "version": "1.6.7", + "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.7.tgz", + "integrity": "sha512-OnAYlL5b7LEkALw87fUVafQw5rVR9RjwGd4KUwNQ6DrrNmaVaUCgLipfVlzrPQ4tWOR9P0IXGNOx50jYCCdSJg==" + }, "node_modules/babel-plugin-macros": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/babel-plugin-macros/-/babel-plugin-macros-3.1.0.tgz", @@ -1450,8 +1607,13 @@ "node_modules/balanced-match": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", - "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", - "devOptional": true + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==" + }, + "node_modules/bare-events": { + "version": "2.6.1", + "resolved": "https://registry.npmjs.org/bare-events/-/bare-events-2.6.1.tgz", + "integrity": "sha512-AuTJkq9XmE6Vk0FJVNq5QxETrSA/vKHarWVBG5l/JbdCL1prJemiyJqUS0jrlXO0MftuPq4m3YVYhoNc5+aE/g==", + "optional": true }, "node_modules/base64-js": { "version": "1.5.1", @@ -1509,7 +1671,6 @@ "version": "2.0.1", "resolved": 
"https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", - "dev": true, "dependencies": { "balanced-match": "^1.0.0" } @@ -1550,6 +1711,14 @@ "ieee754": "^1.1.13" } }, + "node_modules/buffer-crc32": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/buffer-crc32/-/buffer-crc32-1.0.0.tgz", + "integrity": "sha512-Db1SbgBS/fg/392AblrMJk97KggmvYhr4pB5ZIMTWtaivCPMWLkmb7m21cJvpvgK+J3nsU2CmmixNBZx4vFj/w==", + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/buffer-from": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", @@ -1945,7 +2114,6 @@ "version": "2.0.1", "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "devOptional": true, "dependencies": { "color-name": "~1.1.4" }, @@ -1956,8 +2124,7 @@ "node_modules/color-name": { "version": "1.1.4", "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", - "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "devOptional": true + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==" }, "node_modules/color-string": { "version": "1.9.1", @@ -1999,6 +2166,59 @@ "node": ">= 6" } }, + "node_modules/compress-commons": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/compress-commons/-/compress-commons-6.0.2.tgz", + "integrity": "sha512-6FqVXeETqWPoGcfzrXb37E50NP0LXT8kAMu5ooZayhWWdgEY4lBEEcbQNXtkuKQsGduxiIcI4gOTsxTmuq/bSg==", + "dependencies": { + "crc-32": "^1.2.0", + "crc32-stream": "^6.0.0", + "is-stream": "^2.0.1", + "normalize-path": "^3.0.0", + "readable-stream": "^4.0.0" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/compress-commons/node_modules/buffer": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/compress-commons/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + "dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + }, "node_modules/concat-map": { "version": "0.0.1", "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", @@ -2043,6 +2263,11 @@ "resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-1.9.0.tgz", "integrity": "sha512-ASFBup0Mz1uyiIjANan1jzLQami9z1PoYSZCiiYW2FczPbenXc45FZdBZLzOT+r6+iciuEModtmCti+hjaAk0A==" }, + "node_modules/core-util-is": { + "version": "1.0.3", + "resolved": 
"https://registry.npmjs.org/core-util-is/-/core-util-is-1.0.3.tgz", + "integrity": "sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==" + }, "node_modules/cosmiconfig": { "version": "7.1.0", "resolved": "https://registry.npmjs.org/cosmiconfig/-/cosmiconfig-7.1.0.tgz", @@ -2066,6 +2291,67 @@ "node": ">= 6" } }, + "node_modules/crc-32": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/crc-32/-/crc-32-1.2.2.tgz", + "integrity": "sha512-ROmzCKrTnOwybPcJApAA6WBWij23HVfGVNKqqrZpuyZOHqK2CwHSvpGuyt/UNNvaIjEd8X5IFGp4Mh+Ie1IHJQ==", + "bin": { + "crc32": "bin/crc32.njs" + }, + "engines": { + "node": ">=0.8" + } + }, + "node_modules/crc32-stream": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/crc32-stream/-/crc32-stream-6.0.0.tgz", + "integrity": "sha512-piICUB6ei4IlTv1+653yq5+KoqfBYmj9bw6LqXoOneTMDXk5nM1qt12mFW1caG3LlJXEKW1Bp0WggEmIfQB34g==", + "dependencies": { + "crc-32": "^1.2.0", + "readable-stream": "^4.0.0" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/crc32-stream/node_modules/buffer": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/crc32-stream/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + "dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } + }, "node_modules/create-require": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz", @@ -2076,7 +2362,6 @@ "version": "7.0.6", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", - "dev": true, "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", @@ -2222,14 +2507,12 @@ "node_modules/eastasianwidth": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz", - "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==", - "dev": true + "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==" }, "node_modules/emoji-regex": { "version": "9.2.2", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", - "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==", - "dev": true + "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==" }, "node_modules/encoding": { "version": "0.1.13", @@ -2341,6 +2624,22 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/event-target-shim": { + 
"version": "5.0.1", + "resolved": "https://registry.npmjs.org/event-target-shim/-/event-target-shim-5.0.1.tgz", + "integrity": "sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==", + "engines": { + "node": ">=6" + } + }, + "node_modules/events": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/events/-/events-3.3.0.tgz", + "integrity": "sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==", + "engines": { + "node": ">=0.8.x" + } + }, "node_modules/expand-template": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/expand-template/-/expand-template-2.0.3.tgz", @@ -2350,6 +2649,11 @@ "node": ">=6" } }, + "node_modules/fast-fifo": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz", + "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ==" + }, "node_modules/fast-glob": { "version": "3.3.3", "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", @@ -2444,7 +2748,6 @@ "version": "3.3.0", "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.3.0.tgz", "integrity": "sha512-Ld2g8rrAyMYFXBhEqMz8ZAHBi4J4uS1i/CxGMDnjyFWddMXLVcDp051DZfu+t7+ab7Wv6SMqpWmyFIj5UbfFvg==", - "dev": true, "dependencies": { "cross-spawn": "^7.0.0", "signal-exit": "^4.0.1" @@ -2655,7 +2958,6 @@ "version": "10.4.5", "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", - "dev": true, "dependencies": { "foreground-child": "^3.1.0", "jackspeak": "^3.1.2", @@ -2706,8 +3008,7 @@ "version": "4.2.11", "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", - "license": "ISC", - "optional": true + "license": "ISC" }, "node_modules/has-flag": { "version": "4.0.0", @@ -2973,7 +3274,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", - "devOptional": true, "engines": { "node": ">=8" } @@ -3006,17 +3306,31 @@ "node": ">=0.12.0" } }, + "node_modules/is-stream": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", + "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/isarray": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-1.0.0.tgz", + "integrity": "sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==" + }, "node_modules/isexe": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", - "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", - "devOptional": true + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==" }, "node_modules/jackspeak": { "version": "3.4.3", "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", "integrity": 
"sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", - "dev": true, "dependencies": { "@isaacs/cliui": "^8.0.2" }, @@ -3064,6 +3378,44 @@ "resolved": "https://registry.npmjs.org/json-parse-even-better-errors/-/json-parse-even-better-errors-2.3.1.tgz", "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==" }, + "node_modules/lazystream": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/lazystream/-/lazystream-1.0.1.tgz", + "integrity": "sha512-b94GiNHQNy6JNTrt5w6zNyffMrNkXZb3KTkCZJb2V1xaEGCk093vkZ2jk3tpaeP33/OiXC+WvK9AxUebnf5nbw==", + "dependencies": { + "readable-stream": "^2.0.5" + }, + "engines": { + "node": ">= 0.6.3" + } + }, + "node_modules/lazystream/node_modules/readable-stream": { + "version": "2.3.8", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.8.tgz", + "integrity": "sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==", + "dependencies": { + "core-util-is": "~1.0.0", + "inherits": "~2.0.3", + "isarray": "~1.0.0", + "process-nextick-args": "~2.0.0", + "safe-buffer": "~5.1.1", + "string_decoder": "~1.1.1", + "util-deprecate": "~1.0.1" + } + }, + "node_modules/lazystream/node_modules/safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==" + }, + "node_modules/lazystream/node_modules/string_decoder": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.1.1.tgz", + "integrity": "sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==", + "dependencies": { + "safe-buffer": "~5.1.0" + } + }, "node_modules/lilconfig": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.3.tgz", @@ -3084,8 +3436,7 @@ "node_modules/lodash": { "version": "4.17.21", "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", - "dev": true + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==" }, "node_modules/loose-envify": { "version": "1.4.0", @@ -3101,8 +3452,7 @@ "node_modules/lru-cache": { "version": "10.4.3", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", - "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==", - "dev": true + "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==" }, "node_modules/lucide-react": { "version": "0.475.0", @@ -3243,7 +3593,6 @@ "version": "9.0.5", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", - "dev": true, "dependencies": { "brace-expansion": "^2.0.1" }, @@ -3267,7 +3616,6 @@ "version": "7.1.2", "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz", "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==", - "dev": true, "engines": { "node": ">=16 || 14 >=14.17" } @@ -3706,7 +4054,6 @@ "version": "3.0.0", "resolved": 
"https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", - "dev": true, "engines": { "node": ">=0.10.0" } @@ -3773,8 +4120,7 @@ "node_modules/package-json-from-dist": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", - "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==", - "dev": true + "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==" }, "node_modules/parent-module": { "version": "1.0.1", @@ -3818,7 +4164,6 @@ "version": "3.1.1", "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", - "dev": true, "engines": { "node": ">=8" } @@ -3832,7 +4177,6 @@ "version": "1.11.1", "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", - "dev": true, "dependencies": { "lru-cache": "^10.2.0", "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" @@ -4106,6 +4450,19 @@ } } }, + "node_modules/process": { + "version": "0.11.10", + "resolved": "https://registry.npmjs.org/process/-/process-0.11.10.tgz", + "integrity": "sha512-cdGef/drWFoydD1JsMzuFf8100nZl+GT+yacc2bEced5f9Rjk4z+WtFUTBu9PhOi9j/jfmBPu0mMEY4wIdAF8A==", + "engines": { + "node": ">= 0.6.0" + } + }, + "node_modules/process-nextick-args": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/process-nextick-args/-/process-nextick-args-2.0.1.tgz", + "integrity": "sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==" + }, "node_modules/promise-inflight": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/promise-inflight/-/promise-inflight-1.0.1.tgz", @@ -4301,6 +4658,25 @@ "node": ">= 6" } }, + "node_modules/readdir-glob": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/readdir-glob/-/readdir-glob-1.1.3.tgz", + "integrity": "sha512-v05I2k7xN8zXvPD9N+z/uhXPaj0sUFCe2rcWZIpBsqxfP7xXFQ0tipAd/wjj1YxWyWtUS5IDJpOG82JKt2EAVA==", + "dependencies": { + "minimatch": "^5.1.0" + } + }, + "node_modules/readdir-glob/node_modules/minimatch": { + "version": "5.1.6", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", + "integrity": "sha512-lKwV/1brpG6mBUFHtb7NUmtABCb2WZZmm2wNiOA5hAb8VdCS4B3dtMWyvcoViccwAW/COERjXLt0zP1zXUN26g==", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=10" + } + }, "node_modules/readdirp": { "version": "3.6.0", "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", @@ -4557,7 +4933,6 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", - "dev": true, "dependencies": { "shebang-regex": "^3.0.0" }, @@ -4569,7 +4944,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", - "dev": true, "engines": { "node": ">=8" } @@ -4590,7 +4964,6 @@ "version": "4.1.0", "resolved": 
"https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", - "dev": true, "engines": { "node": ">=14" }, @@ -4798,6 +5171,18 @@ "node": ">=10.0.0" } }, + "node_modules/streamx": { + "version": "2.22.1", + "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.22.1.tgz", + "integrity": "sha512-znKXEBxfatz2GBNK02kRnCXjV+AA4kjZIUxeWSr3UGirZMJfTE9uiwKHobnbgxWyL/JWro8tTq+vOqAK1/qbSA==", + "dependencies": { + "fast-fifo": "^1.3.2", + "text-decoder": "^1.1.0" + }, + "optionalDependencies": { + "bare-events": "^2.2.0" + } + }, "node_modules/string_decoder": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", @@ -4811,7 +5196,6 @@ "version": "5.1.2", "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz", "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==", - "dev": true, "dependencies": { "eastasianwidth": "^0.2.0", "emoji-regex": "^9.2.2", @@ -4829,7 +5213,6 @@ "version": "4.2.3", "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", - "dev": true, "dependencies": { "emoji-regex": "^8.0.0", "is-fullwidth-code-point": "^3.0.0", @@ -4843,7 +5226,6 @@ "version": "5.0.1", "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", - "dev": true, "engines": { "node": ">=8" } @@ -4851,14 +5233,12 @@ "node_modules/string-width-cjs/node_modules/emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", - "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", - "dev": true + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==" }, "node_modules/string-width-cjs/node_modules/strip-ansi": { "version": "6.0.1", "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", - "dev": true, "dependencies": { "ansi-regex": "^5.0.1" }, @@ -4870,7 +5250,6 @@ "version": "7.1.0", "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz", "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==", - "dev": true, "dependencies": { "ansi-regex": "^6.0.1" }, @@ -4886,7 +5265,6 @@ "version": "6.0.1", "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", - "dev": true, "dependencies": { "ansi-regex": "^5.0.1" }, @@ -4898,7 +5276,6 @@ "version": "5.0.1", "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", - "dev": true, "engines": { "node": ">=8" } @@ -5098,6 +5475,14 @@ "node": ">=8" } }, + "node_modules/text-decoder": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/text-decoder/-/text-decoder-1.2.3.tgz", + "integrity": 
"sha512-3/o9z3X0X0fTupwsYvR03pJ/DjWuqqrfwBgTQzdWDiQSm9KitAyz/9WqsT2JQW7KV2m+bC2ol/zqpW37NHxLaA==", + "dependencies": { + "b4a": "^1.6.4" + } + }, "node_modules/thenify": { "version": "3.3.1", "resolved": "https://registry.npmjs.org/thenify/-/thenify-3.3.1.tgz", @@ -5393,7 +5778,6 @@ "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", - "devOptional": true, "dependencies": { "isexe": "^2.0.0" }, @@ -5463,7 +5847,6 @@ "version": "8.1.0", "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz", "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==", - "dev": true, "dependencies": { "ansi-styles": "^6.1.0", "string-width": "^5.0.1", @@ -5481,7 +5864,6 @@ "version": "7.0.0", "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", - "dev": true, "dependencies": { "ansi-styles": "^4.0.0", "string-width": "^4.1.0", @@ -5498,7 +5880,6 @@ "version": "5.0.1", "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", - "dev": true, "engines": { "node": ">=8" } @@ -5507,7 +5888,6 @@ "version": "4.3.0", "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, "dependencies": { "color-convert": "^2.0.1" }, @@ -5521,14 +5901,12 @@ "node_modules/wrap-ansi-cjs/node_modules/emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", - "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", - "dev": true + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==" }, "node_modules/wrap-ansi-cjs/node_modules/string-width": { "version": "4.2.3", "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", - "dev": true, "dependencies": { "emoji-regex": "^8.0.0", "is-fullwidth-code-point": "^3.0.0", @@ -5542,7 +5920,6 @@ "version": "6.0.1", "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", - "dev": true, "dependencies": { "ansi-regex": "^5.0.1" }, @@ -5667,6 +6044,57 @@ "engines": { "node": ">=6" } + }, + "node_modules/zip-stream": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/zip-stream/-/zip-stream-6.0.1.tgz", + "integrity": "sha512-zK7YHHz4ZXpW89AHXUPbQVGKI7uvkd3hzusTdotCg1UxyaVtg0zFJSTfW/Dq5f7OBBVnq6cZIaC8Ti4hb6dtCA==", + "dependencies": { + "archiver-utils": "^5.0.0", + "compress-commons": "^6.0.2", + "readable-stream": "^4.0.0" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/zip-stream/node_modules/buffer": { + "version": "6.0.3", + "resolved": "https://registry.npmjs.org/buffer/-/buffer-6.0.3.tgz", + "integrity": "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==", + "funding": [ + { + 
"type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.2.1" + } + }, + "node_modules/zip-stream/node_modules/readable-stream": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.7.0.tgz", + "integrity": "sha512-oIGGmcpTLwPga8Bn6/Z75SVaH1z5dUut2ibSyAMVhmUggWpmDn2dapB0n7f8nwaSiRtepAsfJyfXIO5DCVAODg==", + "dependencies": { + "abort-controller": "^3.0.0", + "buffer": "^6.0.3", + "events": "^3.3.0", + "process": "^0.11.10", + "string_decoder": "^1.3.0" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + } } } } diff --git a/ui/package.json b/ui/package.json index 594aac6dc..fbc06ef2e 100644 --- a/ui/package.json +++ b/ui/package.json @@ -15,6 +15,7 @@ "@headlessui/react": "^2.2.0", "@monaco-editor/react": "^4.7.0", "@prisma/client": "^6.3.1", + "archiver": "^7.0.1", "axios": "^1.7.9", "classnames": "^2.5.1", "lucide-react": "^0.475.0", @@ -32,6 +33,7 @@ "yaml": "^2.7.0" }, "devDependencies": { + "@types/archiver": "^6.0.3", "@types/node": "^20", "@types/react": "^19", "@types/react-dom": "^19", diff --git a/ui/src/app/api/files/[...filePath]/route.ts b/ui/src/app/api/files/[...filePath]/route.ts index 72c2312d4..46eb5c4ab 100644 --- a/ui/src/app/api/files/[...filePath]/route.ts +++ b/ui/src/app/api/files/[...filePath]/route.ts @@ -50,6 +50,7 @@ export async function GET(request: NextRequest, { params }: { params: { filePath '.svg': 'image/svg+xml', '.bmp': 'image/bmp', '.safetensors': 'application/octet-stream', + '.zip': 'application/zip', // Videos '.mp4': 'video/mp4', '.avi': 'video/x-msvideo', diff --git a/ui/src/app/api/img/delete/route.ts b/ui/src/app/api/img/delete/route.ts index d4d968f8e..d213c1c68 100644 --- a/ui/src/app/api/img/delete/route.ts +++ b/ui/src/app/api/img/delete/route.ts @@ -1,17 +1,24 @@ import { NextResponse } from 'next/server'; import fs from 'fs'; -import { getDatasetsRoot } from '@/server/settings'; +import { getDatasetsRoot, getTrainingFolder } from '@/server/settings'; export async function POST(request: Request) { try { const body = await request.json(); const { imgPath } = body; let datasetsPath = await getDatasetsRoot(); + const trainingPath = await getTrainingFolder(); + // make sure the dataset path is in the image path - if (!imgPath.startsWith(datasetsPath)) { + if (!imgPath.startsWith(datasetsPath) && !imgPath.startsWith(trainingPath)) { return NextResponse.json({ error: 'Invalid image path' }, { status: 400 }); } + // make sure it is an image + if (!/\.(jpg|jpeg|png|bmp|gif|tiff|webp)$/i.test(imgPath.toLowerCase())) { + return NextResponse.json({ error: 'Not an image' }, { status: 400 }); + } + // if img doesnt exist, ignore if (!fs.existsSync(imgPath)) { return NextResponse.json({ success: true }); diff --git a/ui/src/app/api/jobs/[jobID]/mark_stopped/route.ts b/ui/src/app/api/jobs/[jobID]/mark_stopped/route.ts new file mode 100644 index 000000000..705c36993 --- /dev/null +++ b/ui/src/app/api/jobs/[jobID]/mark_stopped/route.ts @@ -0,0 +1,26 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await 
prisma.job.findUnique({ + where: { id: jobID }, + }); + + // set the stop flag and mark the job as stopped + await prisma.job.update({ + where: { id: jobID }, + data: { + stop: true, + status: 'stopped', + info: 'Job stopped', + }, + }); + + console.log(`Job ${jobID} marked as stopped`); + + return NextResponse.json(job); +} diff --git a/ui/src/app/api/zip/route.ts b/ui/src/app/api/zip/route.ts new file mode 100644 index 000000000..fc4b946da --- /dev/null +++ b/ui/src/app/api/zip/route.ts @@ -0,0 +1,78 @@ +/* eslint-disable */ +import { NextRequest, NextResponse } from 'next/server'; +import fs from 'fs'; +import fsp from 'fs/promises'; +import path from 'path'; +import archiver from 'archiver'; +import { getTrainingFolder } from '@/server/settings'; + +export const runtime = 'nodejs'; // ensure Node APIs are available +export const dynamic = 'force-dynamic'; // long-running, non-cached + +type PostBody = { + zipTarget: 'samples'; // only 'samples' is supported for now + jobName: string; +}; + +async function resolveSafe(p: string) { + // resolve symlinks + normalize + return await fsp.realpath(p); +} + +export async function POST(request: NextRequest) { + try { + const body = (await request.json()) as PostBody; + if (!body || !body.jobName) { + return NextResponse.json({ error: 'jobName is required' }, { status: 400 }); + } + + const trainingRoot = await resolveSafe(await getTrainingFolder()); + const folderPath = await resolveSafe(path.join(trainingRoot, body.jobName, 'samples')); + const outputPath = path.resolve(trainingRoot, body.jobName, 'samples.zip'); + + // Must be a directory + let stat: fs.Stats; + try { + stat = await fsp.stat(folderPath); + } catch { + return new NextResponse('Folder not found', { status: 404 }); + } + if (!stat.isDirectory()) { + return new NextResponse('Not a directory', { status: 400 }); + } + + // delete current one if it exists + if (fs.existsSync(outputPath)) { + await fsp.unlink(outputPath); + } + + // Create write stream & archive + await new Promise((resolve, reject) => { + const output = fs.createWriteStream(outputPath); + const archive = archiver('zip', { zlib: { level: 9 } }); + + output.on('close', () => resolve()); + output.on('error', reject); + archive.on('error', reject); + + archive.pipe(output); + + // Add the directory contents (place them under the folder's base name in the zip) + const rootName = path.basename(folderPath); + archive.directory(folderPath, rootName); + + archive.finalize().catch(reject); + }); + + // Return the absolute path so your existing /api/files/[...filePath] can serve it + // Example download URL (client-side): `/api/files/${encodeURIComponent(zipPath)}`, using the zipPath returned below + return NextResponse.json({ + ok: true, + zipPath: outputPath, + fileName: path.basename(outputPath), + }); + } catch (err) { + console.error('Zip error:', err); + return new NextResponse('Internal Server Error', { status: 500 }); + } +} diff --git a/ui/src/app/datasets/[datasetName]/page.tsx index d3bf68fcb..dcba47360 100644 --- a/ui/src/app/datasets/[datasetName]/page.tsx +++ b/ui/src/app/datasets/[datasetName]/page.tsx @@ -1,12 +1,14 @@ 'use client'; -import { useEffect, useState, use } from 'react'; +import { useEffect, useState, use, useMemo } from 'react'; +import { LuImageOff, LuLoader, LuBan } from 'react-icons/lu'; import { FaChevronLeft } from 'react-icons/fa'; import DatasetImageCard from '@/components/DatasetImageCard'; import { Button } from '@headlessui/react'; import AddImagesModal, { openImagesModal } from
'@/components/AddImagesModal'; import { TopBar, MainContent } from '@/components/layout'; import { apiClient } from '@/utils/api'; +import FullscreenDropOverlay from '@/components/FullscreenDropOverlay'; export default function DatasetPage({ params }: { params: { datasetName: string } }) { const [imgList, setImgList] = useState<{ img_path: string }[]>([]); @@ -38,6 +40,56 @@ export default function DatasetPage({ params }: { params: { datasetName: string } }, [datasetName]); + const PageInfoContent = useMemo(() => { + let icon = null; + let text = ''; + let subtitle = ''; + let showIt = false; + let bgColor = ''; + let textColor = ''; + let iconColor = ''; + + if (status == 'loading') { + icon = ; + text = 'Loading Images'; + subtitle = 'Please wait while we fetch your dataset images...'; + showIt = true; + bgColor = 'bg-gray-50 dark:bg-gray-800/50'; + textColor = 'text-gray-900 dark:text-gray-100'; + iconColor = 'text-gray-500 dark:text-gray-400'; + } + if (status == 'error') { + icon = ; + text = 'Error Loading Images'; + subtitle = 'There was a problem fetching the images. Please try refreshing the page.'; + showIt = true; + bgColor = 'bg-red-50 dark:bg-red-950/20'; + textColor = 'text-red-900 dark:text-red-100'; + iconColor = 'text-red-600 dark:text-red-400'; + } + if (status == 'success' && imgList.length === 0) { + icon = ; + text = 'No Images Found'; + subtitle = 'This dataset is empty. Click "Add Images" to get started.'; + showIt = true; + bgColor = 'bg-gray-50 dark:bg-gray-800/50'; + textColor = 'text-gray-900 dark:text-gray-100'; + iconColor = 'text-gray-500 dark:text-gray-400'; + } + + if (!showIt) return null; + + return ( +
+
{icon}
+

{text}

+

{subtitle}

+
+ ); + }, [status, imgList.length]); + return ( <> {/* Fixed top bar */} @@ -61,11 +113,9 @@ export default function DatasetPage({ params }: { params: { datasetName: string - {status === 'loading' &&

Loading...

} - {status === 'error' &&

Error fetching images

} - {status === 'success' && ( + {PageInfoContent} + {status === 'success' && imgList.length > 0 && (
- {imgList.length === 0 &&

No images found

} {imgList.map(img => ( + refreshImageList(datasetName)} + /> ); } diff --git a/ui/src/app/jobs/[jobID]/page.tsx b/ui/src/app/jobs/[jobID]/page.tsx index 5a6d5d59e..96fc19368 100644 --- a/ui/src/app/jobs/[jobID]/page.tsx +++ b/ui/src/app/jobs/[jobID]/page.tsx @@ -1,27 +1,47 @@ 'use client'; -import { useMemo, useState, use } from 'react'; +import { useState, use } from 'react'; import { FaChevronLeft } from 'react-icons/fa'; import { Button } from '@headlessui/react'; import { TopBar, MainContent } from '@/components/layout'; import useJob from '@/hooks/useJob'; -import { startJob, stopJob } from '@/utils/jobs'; -import SampleImages from '@/components/SampleImages'; +import SampleImages, {SampleImagesMenu} from '@/components/SampleImages'; import JobOverview from '@/components/JobOverview'; -import { JobConfig } from '@/types'; import { redirect } from 'next/navigation'; import JobActionBar from '@/components/JobActionBar'; +import JobConfigViewer from '@/components/JobConfigViewer'; +import { Job } from '@prisma/client'; -type PageKey = 'overview' | 'samples'; +type PageKey = 'overview' | 'samples' | 'config'; interface Page { name: string; value: PageKey; + component: React.ComponentType<{ job: Job }>; + menuItem?: React.ComponentType<{ job?: Job | null }> | null; + mainCss?: string; } const pages: Page[] = [ - { name: 'Overview', value: 'overview' }, - { name: 'Samples', value: 'samples' }, + { + name: 'Overview', + value: 'overview', + component: JobOverview, + mainCss: 'pt-24', + }, + { + name: 'Samples', + value: 'samples', + component: SampleImages, + menuItem: SampleImagesMenu, + mainCss: 'pt-24', + }, + { + name: 'Config File', + value: 'config', + component: JobConfigViewer, + mainCss: 'pt-[80px] px-0 pb-0', + }, ]; export default function JobPage({ params }: { params: { jobID: string } }) { @@ -30,6 +50,8 @@ export default function JobPage({ params }: { params: { jobID: string } }) { const { job, status, refreshJob } = useJob(jobID, 5000); const [pageKey, setPageKey] = useState('overview'); + const page = pages.find(p => p.value === pageKey); + return ( <> {/* Fixed top bar */} @@ -54,13 +76,15 @@ export default function JobPage({ params }: { params: { jobID: string } }) { /> )} - + page.value === pageKey)?.mainCss}> {status === 'loading' && job == null &&

Loading...

} {status === 'error' && job == null &&

Error fetching job

} {job && ( <> - {pageKey === 'overview' && } - {pageKey === 'samples' && } + {pages.map(page => { + const Component = page.component; + return page.value === pageKey ? : null; + })} )}
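Aside on the tab registry above: because each tab is now an entry in `pages` with its own `component`, `menuItem`, and `mainCss`, adding another job-detail view is just another registry entry plus a new member in the `PageKey` union, with no extra conditional rendering. A minimal sketch, assuming a hypothetical `JobMetrics` component with the same `{ job: Job }` props:

```
import type { ComponentType } from 'react';
import type { Job } from '@prisma/client';

// Hypothetical component, not part of this change set.
declare const JobMetrics: ComponentType<{ job: Job }>;

// 'metrics' would also need to be added to the PageKey union above.
const metricsPage = {
  name: 'Metrics',
  value: 'metrics' as const,
  component: JobMetrics,
  mainCss: 'pt-24',
};
```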
@@ -74,6 +98,15 @@ export default function JobPage({ params }: { params: { jobID: string } }) { {page.name} ))} + { + page?.menuItem && ( + <> +
+
+ + + ) + }
); diff --git a/ui/src/app/jobs/new/AdvancedJob.tsx b/ui/src/app/jobs/new/AdvancedJob.tsx index 6a0d4388c..66938384d 100644 --- a/ui/src/app/jobs/new/AdvancedJob.tsx +++ b/ui/src/app/jobs/new/AdvancedJob.tsx @@ -108,7 +108,7 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props // We have to ensure certain things are always set try { - parsed.config.process[0].type = 'ui_trainer'; + // parsed.config.process[0].type = 'ui_trainer'; parsed.config.process[0].sqlite_db_path = './aitk_db.db'; parsed.config.process[0].training_folder = settings.TRAINING_FOLDER; parsed.config.process[0].device = 'cuda'; diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index c9c4d0f89..4ae5bcc9a 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -1,6 +1,13 @@ 'use client'; import { useMemo } from 'react'; -import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options'; +import { + modelArchs, + ModelArch, + groupedModelOptions, + quantizationOptions, + defaultQtype, + jobTypeOptions, +} from './options'; import { defaultDatasetConfig } from './jobConfig'; import { GroupedSelectOption, JobConfig, SelectOption } from '@/types'; import { objectCopy } from '@/utils/basic'; @@ -8,6 +15,9 @@ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/comp import Card from '@/components/Card'; import { X } from 'lucide-react'; import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal'; +import SampleControlImage from '@/components/SampleControlImage'; +import { FlipHorizontal2, FlipVertical2 } from 'lucide-react'; +import { handleModelArchChange } from './utils'; type Props = { jobConfig: JobConfig; @@ -38,6 +48,21 @@ export default function SimpleJob({ return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; }, [jobConfig.config.process[0].model.arch]); + const jobType = useMemo(() => { + return jobTypeOptions.find(j => j.value === jobConfig.config.process[0].type); + }, [jobConfig.config.process[0].type]); + + const disableSections = useMemo(() => { + let sections: string[] = []; + if (modelArch?.disableSections) { + sections = sections.concat(modelArch.disableSections); + } + if (jobType?.disableSections) { + sections = sections.concat(jobType.disableSections); + } + return sections; + }, [modelArch, jobType]); + const isVideoModel = !!(modelArch?.group === 'video'); const numTopCards = useMemo(() => { @@ -45,12 +70,14 @@ export default function SimpleJob({ if (modelArch?.additionalSections?.includes('model.multistage')) { count += 1; // add multistage card } - if (!modelArch?.disableSections?.includes('model.quantize')) { + if (!disableSections.includes('model.quantize')) { count += 1; // add quantization card } + if (!disableSections.includes('slider')) { + count += 1; // add slider card + } return count; - - }, [modelArch]); + }, [modelArch, disableSections]); let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6'; @@ -61,6 +88,20 @@ export default function SimpleJob({ topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6'; } + const numTrainingCols = useMemo(() => { + let count = 4; + if (!disableSections.includes('train.diff_output_preservation')) { + count += 1; + } + return count; + }, [disableSections]); + + let trainingBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6'; + + if (numTrainingCols 
== 5) { + trainingBarClass = 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6'; + } + const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => { const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0; if (!hasARA) { @@ -77,7 +118,7 @@ export default function SimpleJob({ let ARAs: SelectOption[] = []; if (modelArch.accuracyRecoveryAdapters) { for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) { - ARAs.push({ value, label }); + ARAs.push({ value, label }); } } if (ARAs.length > 0) { @@ -123,19 +164,21 @@ export default function SimpleJob({ onChange={value => setGpuIDs(value)} options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))} /> - { - if (value?.trim() === '') { - value = null; - } - setJobConfig(value, 'config.process[0].trigger_word'); - }} - placeholder="" - required - /> + {disableSections.includes('trigger_word') ? null : ( + { + if (value?.trim() === '') { + value = null; + } + setJobConfig(value, 'config.process[0].trigger_word'); + }} + placeholder="" + required + /> + )} {/* Model Configuration Section */} @@ -144,58 +187,7 @@ export default function SimpleJob({ label="Model Architecture" value={jobConfig.config.process[0].model.arch} onChange={value => { - const currentArch = modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch); - if (!currentArch || currentArch.name === value) { - return; - } - // update the defaults when a model is selected - const newArch = modelArchs.find(model => model.name === value); - - // update vram setting - if (!newArch?.additionalSections?.includes('model.low_vram')) { - setJobConfig(false, 'config.process[0].model.low_vram'); - } - - // revert defaults from previous model - for (const key in currentArch.defaults) { - setJobConfig(currentArch.defaults[key][1], key); - } - - if (newArch?.defaults) { - for (const key in newArch.defaults) { - setJobConfig(newArch.defaults[key][0], key); - } - } - // set new model - setJobConfig(value, 'config.process[0].model.arch'); - - // update datasets - const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false; - const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false; - const controls = newArch?.controls ?? []; - const datasets = jobConfig.config.process[0].datasets.map(dataset => { - const newDataset = objectCopy(dataset); - newDataset.controls = controls; - if (!hasControlPath) { - newDataset.control_path = null; // reset control path if not applicable - } - if (!hasNumFrames) { - newDataset.num_frames = 1; // reset num_frames if not applicable - } - return newDataset; - }); - setJobConfig(datasets, 'config.process[0].datasets'); - - // update samples - const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false; - const samples = jobConfig.config.process[0].sample.samples.map(sample => { - const newSample = objectCopy(sample); - if (!hasSampleCtrlImg) { - delete newSample.ctrl_img; // remove ctrl_img if not applicable - } - return newSample; - }); - setJobConfig(samples, 'config.process[0].sample.samples'); + handleModelArchChange(jobConfig.config.process[0].model.arch, value, jobConfig, setJobConfig); }} options={groupedModelOptions} /> @@ -222,7 +214,7 @@ export default function SimpleJob({ )} - {modelArch?.disableSections?.includes('model.quantize') ? null : ( + {disableSections.includes('model.quantize') ? 
null : ( setJobConfig(value, 'config.process[0].train.switch_boundary_every')} - placeholder="eg. 1" - docKey={'train.switch_boundary_every'} - min={1} - required - /> + label="Switch Every" + value={jobConfig.config.process[0].train.switch_boundary_every} + onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')} + placeholder="eg. 1" + docKey={'train.switch_boundary_every'} + min={1} + required + /> )} @@ -318,7 +310,7 @@ export default function SimpleJob({ max={1024} required /> - {modelArch?.disableSections?.includes('network.conv') ? null : ( + {disableSections.includes('network.conv') ? null : ( )} + {!disableSections.includes('slider') && ( + + setJobConfig(value, 'config.process[0].slider.target_class')} + placeholder="eg. person" + /> + setJobConfig(value, 'config.process[0].slider.positive_prompt')} + placeholder="eg. person who is happy" + /> + setJobConfig(value, 'config.process[0].slider.negative_prompt')} + placeholder="eg. person who is sad" + /> + setJobConfig(value, 'config.process[0].slider.anchor_class')} + placeholder="" + /> + + )}
-
+
- {modelArch?.disableSections?.includes('train.timestep_type') ? null : ( + {disableSections.includes('train.timestep_type') ? null : ( setJobConfig(value, 'config.process[0].train.timestep_type')} options={[ { value: 'sigmoid', label: 'Sigmoid' }, @@ -451,13 +475,15 @@ export default function SimpleJob({ ]} /> setJobConfig(value, 'config.process[0].train.noise_scheduler')} + value={jobConfig.config.process[0].train.loss_type} + onChange={value => setJobConfig(value, 'config.process[0].train.loss_type')} options={[ - { value: 'flowmatch', label: 'FlowMatch' }, - { value: 'ddpm', label: 'DDPM' }, + { value: 'mse', label: 'Mean Squared Error' }, + { value: 'mae', label: 'Mean Absolute Error' }, + { value: 'wavelet', label: 'Wavelet' }, + { value: 'stepped', label: 'Stepped Recovery' }, ]} />
@@ -482,17 +508,19 @@ export default function SimpleJob({ )} - { - setJobConfig(value, 'config.process[0].train.unload_text_encoder'); - if (value) { - setJobConfig(false, 'config.process[0].train.cache_text_embeddings'); - } - }} - /> + {!disableSections.includes('train.unload_text_encoder') && ( + { + setJobConfig(value, 'config.process[0].train.unload_text_encoder'); + if (value) { + setJobConfig(false, 'config.process[0].train.cache_text_embeddings'); + } + }} + /> + )}
- - setJobConfig(value, 'config.process[0].train.diff_output_preservation')} - /> - - {jobConfig.config.process[0].train.diff_output_preservation && ( + {disableSections.includes('train.diff_output_preservation') ? null : ( <> - - setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') - } - placeholder="eg. 1.0" - min={0} - /> - setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')} - placeholder="eg. woman" - /> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation')} + /> + + {jobConfig.config.process[0].train.diff_output_preservation && ( + <> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') + } + placeholder="eg. 1.0" + min={0} + /> + + setJobConfig(value, 'config.process[0].train.diff_output_preservation_class') + } + placeholder="eg. woman" + /> + + )} )}
@@ -561,7 +595,7 @@ export default function SimpleJob({
setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} options={datasetOptions} @@ -578,6 +612,49 @@ export default function SimpleJob({ options={[{ value: '', label: <>  }, ...datasetOptions]} /> )} + {modelArch?.additionalSections?.includes('datasets.multi_control_paths') && ( + <> + + setJobConfig( + value == '' ? null : value, + `config.process[0].datasets[${i}].control_path_1`, + ) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + + setJobConfig( + value == '' ? null : value, + `config.process[0].datasets[${i}].control_path_2`, + ) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + + setJobConfig( + value == '' ? null : value, + `config.process[0].datasets[${i}].control_path_3`, + ) + } + options={[{ value: '', label: <>  }, ...datasetOptions]} + /> + + )} )} + + + Flip X + + } + checked={dataset.flip_x || false} + onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)} + /> + + Flip Y + + } + checked={dataset.flip_y || false} + onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)} + /> +
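For reference, a sketch of the dataset entry these controls produce in the job config when a multi-control model is selected; the folder paths are placeholders, and clearing a select stores `null`, matching the `onChange` handlers above:

```
// Illustrative only; paths are placeholders, not part of this change set.
const exampleDataset = {
  folder_path: '/workspace/datasets/target-images',
  control_path_1: '/workspace/datasets/control-a',
  control_path_2: '/workspace/datasets/control-b',
  control_path_3: null, // an unset select is written back as null
  flip_x: false, // augmentation flags introduced in this change
  flip_y: true,
};
```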
@@ -796,7 +893,28 @@ export default function SimpleJob({ label="Skip First Sample" className="pt-4" checked={jobConfig.config.process[0].train.skip_first_sample || false} - onChange={value => setJobConfig(value, 'config.process[0].train.skip_first_sample')} + onChange={value => { + setJobConfig(value, 'config.process[0].train.skip_first_sample'); + // cannot do both, so disable the other + if (value) { + setJobConfig(false, 'config.process[0].train.force_first_sample'); + } + }} + /> +
+
+ { + setJobConfig(value, 'config.process[0].train.force_first_sample'); + // cannot do both, so disable the other + if (value) { + setJobConfig(false, 'config.process[0].train.skip_first_sample'); + } + }} />
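The "Skip First Sample" and "Force First Sample" toggles above are mutually exclusive: enabling one clears the other. The same pattern, factored into a small helper for clarity; the setter signature mirrors `setJobConfig` but is an assumption here:

```
type BoolSetter = (value: boolean, path: string) => void;

// Enabling one first-sample option disables the other, as in the handlers above.
function setFirstSampleOption(
  key: 'skip_first_sample' | 'force_first_sample',
  value: boolean,
  setJobConfig: BoolSetter,
) {
  setJobConfig(value, `config.process[0].train.${key}`);
  if (value) {
    const other = key === 'skip_first_sample' ? 'force_first_sample' : 'skip_first_sample';
    setJobConfig(false, `config.process[0].train.${other}`);
  }
}
```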
@@ -804,7 +922,13 @@ export default function SimpleJob({ label="Disable Sampling" className="pt-1" checked={jobConfig.config.process[0].train.disable_sampling || false} - onChange={value => setJobConfig(value, 'config.process[0].train.disable_sampling')} + onChange={value => { + setJobConfig(value, 'config.process[0].train.disable_sampling'); + // cannot do both, so disable the other + if (value) { + setJobConfig(false, 'config.process[0].train.force_first_sample'); + } + }} />
@@ -826,31 +950,151 @@ export default function SimpleJob({ placeholder="Enter prompt" required /> +
+ { + // remove any non-numeric characters + value = value.replace(/\D/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].width; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + setJobConfig(intValue, `config.process[0].sample.samples[${i}].width`); + } else { + console.warn('Invalid width value:', value); + } + } + }} + placeholder={`${jobConfig.config.process[0].sample.width} (default)`} + /> + { + // remove any non-numeric characters + value = value.replace(/\D/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].height; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + setJobConfig(intValue, `config.process[0].sample.samples[${i}].height`); + } else { + console.warn('Invalid height value:', value); + } + } + }} + placeholder={`${jobConfig.config.process[0].sample.height} (default)`} + /> + { + // remove any non-numeric characters + value = value.replace(/\D/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].seed; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + const intValue = parseInt(value); + if (!isNaN(intValue)) { + setJobConfig(intValue, `config.process[0].sample.samples[${i}].seed`); + } else { + console.warn('Invalid seed value:', value); + } + } + }} + placeholder={`${jobConfig.config.process[0].sample.walk_seed ? jobConfig.config.process[0].sample.seed + i : jobConfig.config.process[0].sample.seed} (default)`} + /> + { + // remove any non-numeric, - or . characters + value = value.replace(/[^0-9.-]/g, ''); + if (value === '') { + // remove the key from the config if empty + let newConfig = objectCopy(jobConfig); + if (newConfig.config.process[0].sample.samples[i]) { + delete newConfig.config.process[0].sample.samples[i].network_multiplier; + setJobConfig( + newConfig.config.process[0].sample.samples, + 'config.process[0].sample.samples', + ); + } + } else { + // set it as a string + setJobConfig(value, `config.process[0].sample.samples[${i}].network_multiplier`); + return; + } + }} + placeholder={`1.0 (default)`} + /> +
- + {modelArch?.additionalSections?.includes('datasets.multi_control_paths') && ( + +
+ {['ctrl_img_1', 'ctrl_img_2', 'ctrl_img_3'].map((ctrlKey, ctrl_idx) => ( + { + if (!imagePath) { + let newSamples = objectCopy(jobConfig.config.process[0].sample.samples); + delete newSamples[i][ctrlKey as keyof typeof sample]; + setJobConfig(newSamples, 'config.process[0].sample.samples'); + } else { + setJobConfig(imagePath, `config.process[0].sample.samples[${i}].${ctrlKey}`); + } + }} + /> + ))} +
+
+ )} {modelArch?.additionalSections?.includes('sample.ctrl_img') && ( -
{ - openAddImageModal(imagePath => { - console.log('Selected image path:', imagePath); - if (!imagePath) return; + { + if (!imagePath) { + let newSamples = objectCopy(jobConfig.config.process[0].sample.samples); + delete newSamples[i].ctrl_img; + setJobConfig(newSamples, 'config.process[0].sample.samples'); + } else { setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`); - }); + } }} - > - {!sample.ctrl_img && ( -
Add Control Image
- )} -
+ /> )}
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index 8b0c28a4b..54f34ca97 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -1,8 +1,7 @@ -import { JobConfig, DatasetConfig } from '@/types'; +import { JobConfig, DatasetConfig, SliderConfig } from '@/types'; export const defaultDatasetConfig: DatasetConfig = { folder_path: '/path/to/images/folder', - control_path: null, mask_path: null, mask_min_value: 0.1, default_caption: '', @@ -16,6 +15,17 @@ export const defaultDatasetConfig: DatasetConfig = { shrink_video_to_frames: true, num_frames: 1, do_i2v: true, + flip_x: false, + flip_y: false, +}; + +export const defaultSliderConfig: SliderConfig = { + guidance_strength: 3.0, + anchor_strength: 1.0, + positive_prompt: 'person who is happy', + negative_prompt: 'person who is sad', + target_class: 'person', + anchor_class: "", }; export const defaultJobConfig: JobConfig = { @@ -24,7 +34,7 @@ export const defaultJobConfig: JobConfig = { name: 'my_first_lora_v1', process: [ { - type: 'ui_trainer', + type: 'diffusion_trainer', training_folder: 'output', sqlite_db_path: './aitk_db.db', device: 'cuda', @@ -73,12 +83,14 @@ export const defaultJobConfig: JobConfig = { ema_decay: 0.99, }, skip_first_sample: false, + force_first_sample: false, disable_sampling: false, dtype: 'bf16', diff_output_preservation: false, diff_output_preservation_multiplier: 1.0, diff_output_preservation_class: 'person', switch_boundary_every: 1, + loss_type: 'mse', }, model: { name_or_path: 'ostris/Flex.1-alpha', @@ -97,7 +109,7 @@ export const defaultJobConfig: JobConfig = { height: 1024, samples: [ { - prompt: 'woman with red hair, playing chess at the park, bomb going off in the background' + prompt: 'woman with red hair, playing chess at the park, bomb going off in the background', }, { prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe', @@ -106,7 +118,8 @@ export const defaultJobConfig: JobConfig = { prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini', }, { - prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background', + prompt: + 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background', }, { prompt: 'a bear building a log cabin in the snow covered mountains', @@ -118,13 +131,15 @@ export const defaultJobConfig: JobConfig = { prompt: 'hipster man with a beard, building a chair, in a wood shop', }, { - prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop', + prompt: + 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop', }, { prompt: "a man holding a sign that says, 'this is a sign'", }, { - prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle', + prompt: + 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle', }, ], neg: '', @@ -161,5 +176,10 @@ export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => { jobConfig.config.process[0].sample.samples = newSamples; delete jobConfig.config.process[0].sample.prompts; } + + // upgrade job from ui_trainer to diffusion_trainer + if (jobConfig?.config?.process && jobConfig.config.process[0]?.type === 'ui_trainer') { + jobConfig.config.process[0].type = 'diffusion_trainer'; + } return jobConfig; }; diff --git 
a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 71fdc9d8e..247245845 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -1,12 +1,23 @@ -import { GroupedSelectOption, SelectOption } from '@/types'; +import { GroupedSelectOption, SelectOption, JobConfig } from '@/types'; +import { defaultSliderConfig } from './jobConfig'; type Control = 'depth' | 'line' | 'pose' | 'inpaint'; -type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; +type DisableableSections = + | 'model.quantize' + | 'train.timestep_type' + | 'network.conv' + | 'trigger_word' + | 'train.diff_output_preservation' + | 'train.unload_text_encoder' + | 'slider'; + type AdditionalSections = | 'datasets.control_path' + | 'datasets.multi_control_paths' | 'datasets.do_i2v' | 'sample.ctrl_img' + | 'sample.multi_ctrl_imgs' | 'datasets.num_frames' | 'model.multistage' | 'model.low_vram'; @@ -327,6 +338,28 @@ export const modelArchs: ModelArch[] = [ '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors', }, }, + { + name: 'qwen_image_edit_plus', + label: 'Qwen-Image-Edit-2509', + group: 'instruction', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit-2509', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.low_vram': [true, false], + 'config.process[0].train.unload_text_encoder': [false, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], + 'config.process[0].model.qtype': ['qfloat8', 'qfloat8'], + }, + disableSections: ['network.conv', 'train.unload_text_encoder'], + additionalSections: ['datasets.multi_control_paths', 'sample.multi_ctrl_imgs', 'model.low_vram'], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors', + }, + }, { name: 'hidream', label: 'HiDream', @@ -344,6 +377,9 @@ export const modelArchs: ModelArch[] = [ }, disableSections: ['network.conv'], additionalSections: ['model.low_vram'], + accuracyRecoveryAdapters: { + '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/hidream_i1_full_torchao_uint3.safetensors', + }, }, { name: 'hidream_e1', @@ -439,3 +475,33 @@ export const quantizationOptions: SelectOption[] = [ ]; export const defaultQtype = 'qfloat8'; + +interface JobTypeOption extends SelectOption { + disableSections?: DisableableSections[]; + processSections?: string[]; + onActivate?: (config: JobConfig) => JobConfig; + onDeactivate?: (config: JobConfig) => JobConfig; +} + +export const jobTypeOptions: JobTypeOption[] = [ + { + value: 'diffusion_trainer', + label: 'LoRA Trainer', + disableSections: ['slider'], + }, + { + value: 'concept_slider', + label: 'Concept Slider', + disableSections: ['trigger_word', 'train.diff_output_preservation'], + onActivate: (config: JobConfig) => { + // add default slider config + config.config.process[0].slider = { ...defaultSliderConfig }; + return config; + }, + onDeactivate: (config: JobConfig) => { + // remove slider config + delete config.config.process[0].slider; + return config; + }, + }, +]; diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index fb2b85465..b57495d87 100644 --- 
a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -3,6 +3,7 @@ import { useEffect, useState } from 'react'; import { useSearchParams, useRouter } from 'next/navigation'; import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig'; +import { jobTypeOptions } from './options'; import { JobConfig } from '@/types'; import { objectCopy } from '@/utils/basic'; import { useNestedState } from '@/utils/hooks'; @@ -144,6 +145,38 @@ export default function TrainingForm() {
)} + {!showAdvancedView && ( + <> +
+ { + // undo current job type changes + const currentOption = jobTypeOptions.find( + option => option.value === jobConfig?.config.process[0].type, + ); + if (currentOption && currentOption.onDeactivate) { + setJobConfig(currentOption.onDeactivate(objectCopy(jobConfig))); + } + const option = jobTypeOptions.find(option => option.value === value); + if (option) { + if (option.onActivate) { + setJobConfig(option.onActivate(objectCopy(jobConfig))); + } + jobTypeOptions.forEach(opt => { + if (opt.value !== option.value && opt.onDeactivate) { + setJobConfig(opt.onDeactivate(objectCopy(jobConfig))); + } + }); + } + setJobConfig(value, 'config.process[0].type'); + }} + options={jobTypeOptions} + /> +
+
+ + )}
+
+ + + + + + +
{ + let message = `Are you sure you want to mark this job as stopped? This will set the job status to 'stopped' if the status is hung. Only do this if you are 100% sure the job is stopped. This will NOT stop the job.`; + openConfirm({ + title: 'Mark Job as Stopped', + message: message, + type: 'warning', + confirmText: 'Mark as Stopped', + onConfirm: async () => { + await markJobAsStopped(job.id); + onRefresh && onRefresh(); + }, + }); + }} + > + Mark as Stopped +
+
+
+
); } diff --git a/ui/src/components/JobConfigViewer.tsx b/ui/src/components/JobConfigViewer.tsx new file mode 100644 index 000000000..ce1df9c8b --- /dev/null +++ b/ui/src/components/JobConfigViewer.tsx @@ -0,0 +1,49 @@ +'use client'; +import { useEffect, useState } from 'react'; +import YAML from 'yaml'; +import Editor from '@monaco-editor/react'; + +import { Job } from '@prisma/client'; + +interface Props { + job: Job; +} + +const yamlConfig: YAML.DocumentOptions & + YAML.SchemaOptions & + YAML.ParseOptions & + YAML.CreateNodeOptions & + YAML.ToStringOptions = { + indent: 2, + lineWidth: 999999999999, + defaultStringType: 'QUOTE_DOUBLE', + defaultKeyType: 'PLAIN', + directives: true, +}; + +export default function JobConfigViewer({ job }: Props) { + const [editorValue, setEditorValue] = useState(''); + useEffect(() => { + if (job?.job_config) { + const yamlContent = YAML.stringify(JSON.parse(job.job_config), yamlConfig); + setEditorValue(yamlContent); + } + }, [job]); + return ( + <> + + + ); +} diff --git a/ui/src/components/SampleControlImage.tsx b/ui/src/components/SampleControlImage.tsx new file mode 100644 index 000000000..b10b2cef4 --- /dev/null +++ b/ui/src/components/SampleControlImage.tsx @@ -0,0 +1,206 @@ +'use client'; + +import React, { useCallback, useMemo, useRef, useState } from 'react'; +import classNames from 'classnames'; +import { useDropzone } from 'react-dropzone'; +import { FaUpload, FaImage, FaTimes } from 'react-icons/fa'; +import { apiClient } from '@/utils/api'; +import type { AxiosProgressEvent } from 'axios'; + +interface Props { + src: string | null | undefined; + className?: string; + instruction?: string; + onNewImageSelected: (imagePath: string | null) => void; +} + +export default function SampleControlImage({ + src, + className, + instruction = 'Add Control Image', + onNewImageSelected, +}: Props) { + const [isUploading, setIsUploading] = useState(false); + const [uploadProgress, setUploadProgress] = useState(0); + const [localPreview, setLocalPreview] = useState(null); + const fileInputRef = useRef(null); + + const backgroundUrl = useMemo(() => { + if (localPreview) return localPreview; + if (src) return `/api/img/${encodeURIComponent(src)}`; + return null; + }, [src, localPreview]); + + const handleUpload = useCallback( + async (file: File) => { + if (!file) return; + setIsUploading(true); + setUploadProgress(0); + + const objectUrl = URL.createObjectURL(file); + setLocalPreview(objectUrl); + + const formData = new FormData(); + formData.append('files', file); + + try { + const resp = await apiClient.post(`/api/img/upload`, formData, { + headers: { 'Content-Type': 'multipart/form-data' }, + onUploadProgress: (evt: AxiosProgressEvent) => { + const total = evt.total ?? 100; + const loaded = evt.loaded ?? 0; + setUploadProgress(Math.round((loaded * 100) / total)); + }, + timeout: 0, + }); + + const uploaded = resp?.data?.files?.[0] ?? 
null; + onNewImageSelected(uploaded); + } catch (err) { + console.error('Upload failed:', err); + setLocalPreview(null); + } finally { + setIsUploading(false); + setUploadProgress(0); + URL.revokeObjectURL(objectUrl); + if (fileInputRef.current) fileInputRef.current.value = ''; + } + }, + [onNewImageSelected], + ); + + const onDrop = useCallback( + (acceptedFiles: File[]) => { + if (acceptedFiles.length === 0) return; + handleUpload(acceptedFiles[0]); + }, + [handleUpload], + ); + + const clearImage = useCallback( + (e?: React.MouseEvent) => { + console.log('clearImage'); + if (e) { + e.stopPropagation(); + e.preventDefault(); + } + setLocalPreview(null); + onNewImageSelected(null); + if (fileInputRef.current) fileInputRef.current.value = ''; + }, + [onNewImageSelected], + ); + + // Drag & drop only; click handled via our own hidden input + const { getRootProps, isDragActive } = useDropzone({ + onDrop, + accept: { 'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'] }, + multiple: false, + noClick: true, + noKeyboard: true, + }); + + const rootProps = getRootProps(); + + return ( +
!isUploading && fileInputRef.current?.click()} + > + {/* Hidden input for click-to-open */} + { + const file = e.currentTarget.files?.[0]; + if (file) handleUpload(file); + }} + /> + + {/* Empty state — centered */} + {!backgroundUrl && ( +
+ +
{instruction}
+
Click or drop
+
+ )} + + {/* Existing image overlays */} + {backgroundUrl && !isUploading && ( + <> +
+
+ + Replace +
+
+ + {/* Clear (X) button */} + + + )} + + {/* Uploading overlay */} + {isUploading && ( +
+
+
+
+
+
Uploading… {uploadProgress}%
+
+
+ )} +
+ ); +} diff --git a/ui/src/components/SampleImageCard.tsx b/ui/src/components/SampleImageCard.tsx index 43a41119c..7f01bc8f9 100644 --- a/ui/src/components/SampleImageCard.tsx +++ b/ui/src/components/SampleImageCard.tsx @@ -1,5 +1,4 @@ import React, { useRef, useEffect, useState, ReactNode } from 'react'; -import { sampleImageModalState } from '@/components/SampleImageModal'; import { isVideo } from '@/utils/basic'; interface SampleImageCardProps { @@ -10,6 +9,11 @@ interface SampleImageCardProps { children?: ReactNode; className?: string; onDelete?: () => void; + onClick?: () => void; + /** pass your scroll container element (e.g. containerRef.current) */ + observerRoot?: Element | null; + /** optional: tweak pre-load buffer */ + rootMargin?: string; // default '200px 0px' } const SampleImageCard: React.FC = ({ @@ -19,71 +23,84 @@ const SampleImageCard: React.FC = ({ sampleImages, children, className = '', + onClick = () => {}, + observerRoot = null, + rootMargin = '200px 0px', }) => { const cardRef = useRef(null); - const [isVisible, setIsVisible] = useState(false); - const [loaded, setLoaded] = useState(false); + const videoRef = useRef(null); + const [isVisible, setIsVisible] = useState(false); + const [loaded, setLoaded] = useState(false); + // Observe both enter and exit useEffect(() => { - // Create intersection observer to check visibility + const el = cardRef.current; + if (!el) return; + const observer = new IntersectionObserver( entries => { - if (entries[0].isIntersecting) { - setIsVisible(true); - observer.disconnect(); + for (const entry of entries) { + if (entry.target === el) { + setIsVisible(entry.isIntersecting); + } } }, - { threshold: 0.1 }, + { + root: observerRoot ?? null, + threshold: 0.01, + rootMargin, + }, ); - if (cardRef.current) { - observer.observe(cardRef.current); - } + observer.observe(el); + return () => observer.disconnect(); + }, [observerRoot, rootMargin]); - return () => { - observer.disconnect(); - }; - }, []); + // Pause video when leaving viewport + useEffect(() => { + if (!isVideo(imageUrl)) return; + const v = videoRef.current; + if (!v) return; + if (!isVisible && !v.paused) { + try { + v.pause(); + } catch {} + } + }, [isVisible, imageUrl]); - const handleLoad = (): void => { - setLoaded(true); - }; - console.log('imgurl',imageUrl.toLowerCase().slice(-4)) + const handleLoad = () => setLoaded(true); return (
- {/* Square image container */} -
sampleImageModalState.set({ imgPath: imageUrl, numSamples, sampleImages })} - > +
- {isVisible && ( - <> - {isVideo(imageUrl) ? ( -
diff --git a/ui/src/components/SampleImageModal.tsx b/ui/src/components/SampleImageModal.tsx deleted file mode 100644 index a98734e07..000000000 --- a/ui/src/components/SampleImageModal.tsx +++ /dev/null @@ -1,190 +0,0 @@ -'use client'; -import { useState, useEffect, useMemo } from 'react'; -import { createGlobalState } from 'react-global-hooks'; -import { Dialog, DialogBackdrop, DialogPanel } from '@headlessui/react'; - -export interface SampleImageModalState { - imgPath: string; - numSamples: number; - sampleImages: string[]; -} - -export const sampleImageModalState = createGlobalState(null); - -export const openSampleImage = (sampleImageProps: SampleImageModalState) => { - sampleImageModalState.set(sampleImageProps); -}; - -export default function SampleImageModal() { - const [imageModal, setImageModal] = sampleImageModalState.use(); - const [isOpen, setIsOpen] = useState(false); - - useEffect(() => { - if (imageModal) { - setIsOpen(true); - } - }, [imageModal]); - - useEffect(() => { - if (!isOpen) { - // use timeout to allow the dialog to close before resetting the state - setTimeout(() => { - setImageModal(null); - }, 500); - } - }, [isOpen]); - - const onCancel = () => { - setIsOpen(false); - }; - - const imgInfo = useMemo(() => { - const ii = { - filename: '', - step: 0, - promptIdx: 0, - }; - if (imageModal?.imgPath) { - const filename = imageModal.imgPath.split('/').pop(); - if (!filename) return ii; - // filename is ___. - ii.filename = filename as string; - const parts = filename - .split('.')[0] - .split('_') - .filter(p => p !== ''); - if (parts.length === 3) { - ii.step = parseInt(parts[1]); - ii.promptIdx = parseInt(parts[2]); - } - } - return ii; - }, [imageModal]); - - const handleArrowUp = () => { - if (!imageModal) return; - console.log('Arrow Up pressed'); - // Change image to same sample but up one step - const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath); - if (currentIdx === -1) return; - const nextIdx = currentIdx - imageModal.numSamples; - if (nextIdx < 0) return; - openSampleImage({ - imgPath: imageModal.sampleImages[nextIdx], - numSamples: imageModal.numSamples, - sampleImages: imageModal.sampleImages, - }); - }; - - const handleArrowDown = () => { - if (!imageModal) return; - console.log('Arrow Down pressed'); - // Change image to same sample but down one step - const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath); - if (currentIdx === -1) return; - const nextIdx = currentIdx + imageModal.numSamples; - if (nextIdx >= imageModal.sampleImages.length) return; - openSampleImage({ - imgPath: imageModal.sampleImages[nextIdx], - numSamples: imageModal.numSamples, - sampleImages: imageModal.sampleImages, - }); - }; - - const handleArrowLeft = () => { - if (!imageModal) return; - if (imgInfo.promptIdx === 0) return; - console.log('Arrow Left pressed'); - // go to previous sample - const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath); - if (currentIdx === -1) return; - const minIdx = currentIdx - imgInfo.promptIdx; - const nextIdx = currentIdx - 1; - if (nextIdx < minIdx) return; - openSampleImage({ - imgPath: imageModal.sampleImages[nextIdx], - numSamples: imageModal.numSamples, - sampleImages: imageModal.sampleImages, - }); - }; - - const handleArrowRight = () => { - if (!imageModal) return; - console.log('Arrow Right pressed'); - // go to next sample - const currentIdx = imageModal.sampleImages.findIndex(img => img === imageModal.imgPath); - if (currentIdx === 
-1) return; - const stepMinIdx = currentIdx - imgInfo.promptIdx; - const maxIdx = stepMinIdx + imageModal.numSamples - 1; - const nextIdx = currentIdx + 1; - if (nextIdx > maxIdx) return; - if (nextIdx >= imageModal.sampleImages.length) return; - openSampleImage({ - imgPath: imageModal.sampleImages[nextIdx], - numSamples: imageModal.numSamples, - sampleImages: imageModal.sampleImages, - }); - }; - - // Handle keyboard events - useEffect(() => { - const handleKeyDown = (event: KeyboardEvent) => { - if (!isOpen) return; - - switch (event.key) { - case 'Escape': - onCancel(); - break; - case 'ArrowUp': - handleArrowUp(); - break; - case 'ArrowDown': - handleArrowDown(); - break; - case 'ArrowLeft': - handleArrowLeft(); - break; - case 'ArrowRight': - handleArrowRight(); - break; - default: - break; - } - }; - - window.addEventListener('keydown', handleKeyDown); - - return () => { - window.removeEventListener('keydown', handleKeyDown); - }; - }, [isOpen, imageModal, imgInfo]); - - return ( - - - -
-            {imageModal?.imgPath && (
-              Sample Image
-            )}
-            step: {imgInfo.step}
- ); -} diff --git a/ui/src/components/SampleImageViewer.tsx b/ui/src/components/SampleImageViewer.tsx new file mode 100644 index 000000000..b5fa1b016 --- /dev/null +++ b/ui/src/components/SampleImageViewer.tsx @@ -0,0 +1,283 @@ +'use client'; +import { useState, useEffect, useMemo, useCallback } from 'react'; +import { createPortal } from 'react-dom'; +import { Dialog, DialogBackdrop, DialogPanel } from '@headlessui/react'; +import { SampleConfig, SampleItem } from '@/types'; +import { Cog } from 'lucide-react'; +import { Menu, MenuButton, MenuItem, MenuItems } from '@headlessui/react'; +import { openConfirm } from './ConfirmModal'; +import { apiClient } from '@/utils/api'; + +interface Props { + imgPath: string | null; // current image path + numSamples: number; // number of samples per row + sampleImages: string[]; // all sample images + sampleConfig: SampleConfig | null; + onChange: (nextPath: string | null) => void; // parent setter + refreshSampleImages?: () => void; +} + +export default function SampleImageViewer({ + imgPath, + numSamples, + sampleImages, + sampleConfig, + onChange, + refreshSampleImages, +}: Props) { + const [mounted, setMounted] = useState(false); + const [isOpen, setIsOpen] = useState(Boolean(imgPath)); + + useEffect(() => setMounted(true), []); + + // open/close based on external value + useEffect(() => { + setIsOpen(Boolean(imgPath)); + }, [imgPath]); + + // after close, clear parent state post-transition + useEffect(() => { + if (!isOpen && imgPath) { + const t = setTimeout(() => onChange(null), 300); + return () => clearTimeout(t); + } + }, [isOpen, imgPath, onChange]); + + const onCancel = useCallback(() => setIsOpen(false), []); + + const imgInfo = useMemo(() => { + const ii = { filename: '', step: 0, promptIdx: 0 }; + if (imgPath) { + const filename = imgPath.split('/').pop(); + if (!filename) return ii; + ii.filename = filename; + const parts = filename + .split('.')[0] + .split('_') + .filter(p => p !== ''); + if (parts.length === 3) { + ii.step = parseInt(parts[1]); + ii.promptIdx = parseInt(parts[2]); + } + } + return ii; + }, [imgPath]); + + const setImageAtIndex = useCallback( + (idx: number) => { + if (idx < 0 || idx >= sampleImages.length) return; + onChange(sampleImages[idx]); + }, + [sampleImages, numSamples, onChange], + ); + + const currentIndex = useMemo(() => { + if (!imgPath) return -1; + return sampleImages.findIndex(img => img === imgPath); + }, [imgPath, sampleImages]); + + const handleArrowUp = useCallback(() => { + if (currentIndex === -1) return; + setImageAtIndex(currentIndex - numSamples); + }, [numSamples, currentIndex, setImageAtIndex]); + + const handleArrowDown = useCallback(() => { + if (currentIndex === -1) return; + setImageAtIndex(currentIndex + numSamples); + }, [numSamples, currentIndex, setImageAtIndex]); + + const handleArrowLeft = useCallback(() => { + if (currentIndex === -1) return; + if (imgInfo.promptIdx === 0) return; + const minIdx = currentIndex - imgInfo.promptIdx; + const nextIdx = currentIndex - 1; + if (nextIdx < minIdx) return; + setImageAtIndex(nextIdx); + }, [sampleImages, currentIndex, imgInfo.promptIdx, setImageAtIndex]); + + const handleArrowRight = useCallback(() => { + if (currentIndex === -1) return; + const stepMinIdx = currentIndex - imgInfo.promptIdx; + const maxIdx = stepMinIdx + numSamples - 1; + const nextIdx = currentIndex + 1; + if (nextIdx > maxIdx) return; + setImageAtIndex(nextIdx); + }, [sampleImages, currentIndex, imgInfo.promptIdx, setImageAtIndex]); + + const sampleItem = useMemo(() => 
{ + if (!sampleConfig) return null; + if (imgInfo.promptIdx < 0) return null; + if (imgInfo.promptIdx >= sampleConfig.samples.length) return null; + return sampleConfig.samples[imgInfo.promptIdx]; + }, [sampleConfig, imgInfo.promptIdx]); + + const controlImages = useMemo(() => { + if (!imgPath) return []; + let controlImageArr: string[] = []; + if (sampleItem?.ctrl_img) { + // can be a an array of paths, or a single path + if (Array.isArray(sampleItem.ctrl_img)) { + controlImageArr = sampleItem.ctrl_img; + } else { + controlImageArr = [sampleItem.ctrl_img]; + } + } else if (sampleItem?.ctrl_img_1) { + controlImageArr.push(sampleItem.ctrl_img_1); + } + if (sampleItem?.ctrl_img_2) { + controlImageArr.push(sampleItem.ctrl_img_2); + } + if (sampleItem?.ctrl_img_3) { + controlImageArr.push(sampleItem.ctrl_img_3); + } + // filter out nulls + controlImageArr = controlImageArr.filter(ci => ci !== null && ci !== undefined && ci !== ''); + return controlImageArr; + }, [sampleItem, imgPath]); + + const seed = useMemo(() => { + if (!sampleItem) return '?'; + if (sampleItem.seed !== undefined) return sampleItem.seed; + if (sampleConfig?.walk_seed) { + return sampleConfig.seed + imgInfo.promptIdx; + } + return sampleConfig?.seed ?? '?'; + }, [sampleItem, sampleConfig]); + + // keyboard events while open + useEffect(() => { + const handleKeyDown = (event: KeyboardEvent) => { + if (!isOpen) return; + switch (event.key) { + case 'Escape': + onCancel(); + break; + case 'ArrowUp': + handleArrowUp(); + break; + case 'ArrowDown': + handleArrowDown(); + break; + case 'ArrowLeft': + handleArrowLeft(); + break; + case 'ArrowRight': + handleArrowRight(); + break; + default: + break; + } + }; + window.addEventListener('keydown', handleKeyDown); + return () => window.removeEventListener('keydown', handleKeyDown); + }, [isOpen, onCancel, handleArrowUp, handleArrowDown, handleArrowLeft, handleArrowRight]); + + if (!mounted) return null; + + return createPortal( + + +
+          {imgPath && (
+            Sample Image
+          )}
+          {/* # make full width */}
+          {sampleItem?.prompt && (
+            Prompt: {sampleItem.prompt}
+          )}
+          {controlImages.length > 0 && (
+            {controlImages.map((ci, idx) => (
+              Control
+            ))}
+          )}
+          Step: {imgInfo.step.toLocaleString()}
+          Sample #: {imgInfo.promptIdx + 1}
+          Seed: {seed}
{ + let message = `Are you sure you want to delete this sample? This action cannot be undone.`; + openConfirm({ + title: 'Delete Sample', + message: message, + type: 'warning', + confirmText: 'Delete', + onConfirm: () => { + apiClient + .post('/api/img/delete', { imgPath: imgPath }) + .then(() => { + console.log('Image deleted:', imgPath); + onChange(null); + if (refreshSampleImages) { + refreshSampleImages(); + } + }) + .catch(error => { + console.error('Error deleting image:', error); + }); + }, + }); + }} + > + Delete Sample +
, + document.body, + ); +} diff --git a/ui/src/components/SampleImages.tsx b/ui/src/components/SampleImages.tsx index 8f4b80d90..b55d42eb0 100644 --- a/ui/src/components/SampleImages.tsx +++ b/ui/src/components/SampleImages.tsx @@ -1,8 +1,66 @@ -import { useMemo } from 'react'; +import { useMemo, useState, useRef, useEffect } from 'react'; import useSampleImages from '@/hooks/useSampleImages'; import SampleImageCard from './SampleImageCard'; import { Job } from '@prisma/client'; import { JobConfig } from '@/types'; +import { LuImageOff, LuLoader, LuBan } from 'react-icons/lu'; +import { Button } from '@headlessui/react'; +import { FaDownload } from 'react-icons/fa'; +import { apiClient } from '@/utils/api'; +import classNames from 'classnames'; +import { FaCaretDown, FaCaretUp } from 'react-icons/fa'; +import SampleImageViewer from './SampleImageViewer'; + +interface SampleImagesMenuProps { + job?: Job | null; +} + +export const SampleImagesMenu = ({ job }: SampleImagesMenuProps) => { + const [isZipping, setIsZipping] = useState(false); + + const downloadZip = async () => { + if (isZipping) return; + setIsZipping(true); + + try { + const res = await apiClient.post('/api/zip', { + zipTarget: 'samples', + jobName: job?.name, + }); + + const zipPath = res.data.zipPath; // e.g. /mnt/Train2/out/ui/.../samples.zip + if (!zipPath) throw new Error('No zipPath in response'); + + const downloadPath = `/api/files/${encodeURIComponent(zipPath)}`; + const a = document.createElement('a'); + a.href = downloadPath; + // optional: suggest filename (browser may ignore if server sets Content-Disposition) + a.download = 'samples.zip'; + document.body.appendChild(a); + a.click(); + a.remove(); + } catch (err) { + console.error('Error downloading zip:', err); + } finally { + setIsZipping(false); + } + }; + return ( + + ); +}; interface SampleImagesProps { job: Job; @@ -10,19 +68,84 @@ interface SampleImagesProps { export default function SampleImages({ job }: SampleImagesProps) { const { sampleImages, status, refreshSampleImages } = useSampleImages(job.id, 5000); + const [selectedSamplePath, setSelectedSamplePath] = useState(null); + const containerRef = useRef(null); + const didFirstScroll = useRef(false); const numSamples = useMemo(() => { if (job?.job_config) { const jobConfig = JSON.parse(job.job_config) as JobConfig; const sampleConfig = jobConfig.config.process[0].sample; - if (sampleConfig.prompts) { - return sampleConfig.prompts.length; - } else { - return sampleConfig.samples.length; - } + const numPrompts = sampleConfig.prompts ? 
sampleConfig.prompts.length : 0; + const numSamples = sampleConfig.samples.length; + return Math.max(numPrompts, numSamples, 1); } return 10; }, [job]); + const scrollToBottom = () => { + if (containerRef.current) { + containerRef.current.scrollTo({ top: containerRef.current.scrollHeight, behavior: 'instant' }); + } + }; + + const scrollToTop = () => { + if (containerRef.current) { + containerRef.current.scrollTo({ top: 0, behavior: 'instant' }); + } + }; + + const PageInfoContent = useMemo(() => { + let icon = null; + let text = ''; + let subtitle = ''; + let showIt = false; + let bgColor = ''; + let textColor = ''; + let iconColor = ''; + + if (sampleImages.length > 0) return null; + + if (status == 'loading') { + icon = ; + text = 'Loading Samples'; + subtitle = 'Please wait while we fetch your samples...'; + showIt = true; + bgColor = 'bg-gray-50 dark:bg-gray-800/50'; + textColor = 'text-gray-900 dark:text-gray-100'; + iconColor = 'text-gray-500 dark:text-gray-400'; + } + if (status == 'error') { + icon = ; + text = 'Error Loading Samples'; + subtitle = 'There was a problem fetching the samples.'; + showIt = true; + bgColor = 'bg-red-50 dark:bg-red-950/20'; + textColor = 'text-red-900 dark:text-red-100'; + iconColor = 'text-red-600 dark:text-red-400'; + } + if (status == 'success' && sampleImages.length === 0) { + icon = ; + text = 'No Samples Found'; + subtitle = 'No samples have been generated yet'; + showIt = true; + bgColor = 'bg-gray-50 dark:bg-gray-800/50'; + textColor = 'text-gray-900 dark:text-gray-100'; + iconColor = 'text-gray-500 dark:text-gray-400'; + } + + if (!showIt) return null; + + return ( +
+          {icon}
+          {text}
+          {subtitle}
+ ); + }, [status, sampleImages.length]); + // Use direct Tailwind class without string interpolation // This way Tailwind can properly generate the class // I hate this, but it's the only way to make it work @@ -115,11 +238,28 @@ export default function SampleImages({ job }: SampleImagesProps) { } }, [numSamples]); + const sampleConfig = useMemo(() => { + if (job?.job_config) { + const jobConfig = JSON.parse(job.job_config) as JobConfig; + return jobConfig.config.process[0].sample; + } + return null; + }, [job]); + + // scroll to bottom on first load of samples + useEffect(() => { + if (status === 'success' && sampleImages.length > 0 && !didFirstScroll.current) { + didFirstScroll.current = true; + setTimeout(() => { + scrollToBottom(); + }, 100); + } + }, [status, sampleImages.length]); + return ( -
+
-          {status === 'loading' && sampleImages.length === 0 && Loading...}
-          {status === 'error' && Error fetching sample images}
+          {PageInfoContent}
           {sampleImages && (
{sampleImages.map((sample: string) => ( @@ -129,11 +269,35 @@ export default function SampleImages({ job }: SampleImagesProps) { numSamples={numSamples} sampleImages={sampleImages} alt="Sample Image" + onClick={() => setSelectedSamplePath(sample)} + observerRoot={containerRef.current} /> ))}
)}
+      <SampleImageViewer
+        imgPath={selectedSamplePath}
+        numSamples={numSamples}
+        sampleImages={sampleImages}
+        onChange={setPath => setSelectedSamplePath(setPath)}
+        sampleConfig={sampleConfig}
+        refreshSampleImages={refreshSampleImages}
+      />
); } diff --git a/ui/src/components/formInputs.tsx b/ui/src/components/formInputs.tsx index 9dacc336f..90f543db1 100644 --- a/ui/src/components/formInputs.tsx +++ b/ui/src/components/formInputs.tsx @@ -68,8 +68,8 @@ export const TextInput = forwardRef((props: Te TextInput.displayName = 'TextInput'; export interface NumberInputProps extends InputProps { - value: number; - onChange: (value: number) => void; + value: number | null; + onChange: (value: number | null) => void; min?: number; max?: number; } @@ -201,7 +201,7 @@ export const SelectInput = (props: SelectInputProps) => { }; export interface CheckboxProps { - label?: string; + label?: string | React.ReactNode; checked: boolean; onChange: (checked: boolean) => void; className?: string; diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx index 3e8359eb1..b93ae3a22 100644 --- a/ui/src/docs.tsx +++ b/ui/src/docs.tsx @@ -53,10 +53,26 @@ const docs: { [key: string]: ConfigDoc } = { }, 'datasets.control_path': { title: 'Control Dataset', + description: ( + <> + The control dataset needs to have files that match the filenames of your training dataset. They should be + matching file pairs. These images are fed as control/input images during training. The control images will be + resized to match the training images. + + ), + }, + 'datasets.multi_control_paths': { + title: 'Multi Control Dataset', description: ( <> The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs. These images are fed as control/input images during training. +
+        <br />
+        For multi control datasets, the controls will all be applied in the order they are listed. If the model does
+        not require the control images to share the target's aspect ratio, such as with Qwen/Qwen-Image-Edit-2509,
+        then the control images do not need to match the size or aspect ratio of the target image and they will be
+        automatically resized to the ideal resolution for the model and target images.
       </>
     ),
   },
@@ -90,6 +106,20 @@ const docs: { [key: string]: ConfigDoc } = {
       </>
     ),
   },
+  'datasets.flip': {
+    title: 'Flip X and Flip Y',
+    description: (
+      <>
+        You can augment your dataset on the fly by flipping the x (horizontal) and/or y (vertical) axis. Flipping a
+        single axis will effectively double your dataset: it results in training on the normal images and on the
+        flipped versions of them. This can be very helpful, but keep in mind it can also be destructive. There is no
+        reason to train people upside down, and flipping a face can confuse the model, as a person's right side does
+        not look identical to their left side. Flipping images that contain text is obviously not a good idea.
+        <br />
+        <br />
+ Control images for a dataset will also be flipped to match the images, so they will always match on the pixel level. + + ), + }, 'train.unload_text_encoder': { title: 'Unload Text Encoder', description: ( @@ -141,6 +171,17 @@ const docs: { [key: string]: ConfigDoc } = { ), }, + 'train.force_first_sample': { + title: 'Force First Sample', + description: ( + <> + This option will force the trainer to generate samples when it starts. The trainer will normally only generate a first sample + when nothing has been trained yet, but will not do a first sample when resuming from an existing checkpoint. This option + forces a first sample every time the trainer is started. This can be useful if you have changed sample prompts and want to see + the new prompts right away. + + ), + }, }; export const getDoc = (key: string | null | undefined): ConfigDoc | null => { diff --git a/ui/src/types.ts b/ui/src/types.ts index d7da45b93..71048ca5c 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -83,10 +83,15 @@ export interface DatasetConfig { cache_latents_to_disk?: boolean; resolution: number[]; controls: string[]; - control_path: string | null; + control_path?: string | null; num_frames: number; shrink_video_to_frames: boolean; do_i2v: boolean; + flip_x: boolean; + flip_y: boolean; + control_path_1?: string | null; + control_path_2?: string | null; + control_path_3?: string | null; } export interface EMAConfig { @@ -115,11 +120,13 @@ export interface TrainConfig { weight_decay: number; }; skip_first_sample: boolean; + force_first_sample: boolean; disable_sampling: boolean; diff_output_preservation: boolean; diff_output_preservation_multiplier: number; diff_output_preservation_class: string; switch_boundary_every: number; + loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped'; } export interface QuantizeKwargsConfig { @@ -140,7 +147,7 @@ export interface ModelConfig { export interface SampleItem { prompt: string; - width?: number + width?: number; height?: number; neg?: string; seed?: number; @@ -150,6 +157,10 @@ export interface SampleItem { num_frames?: number; ctrl_img?: string | null; ctrl_idx?: number; + network_multiplier?: number; + ctrl_img_1?: string | null; + ctrl_img_2?: string | null; + ctrl_img_3?: string | null; } export interface SampleConfig { @@ -168,14 +179,24 @@ export interface SampleConfig { fps: number; } +export interface SliderConfig { + guidance_strength?: number; + anchor_strength?: number; + positive_prompt?: string; + negative_prompt?: string; + target_class?: string; + anchor_class?: string | null; +} + export interface ProcessConfig { - type: 'ui_trainer'; + type: string; sqlite_db_path?: string; training_folder: string; performance_log_every: number; trigger_word: string | null; device: string; network?: NetworkConfig; + slider?: SliderConfig; save: SaveConfig; datasets: DatasetConfig[]; train: TrainConfig; diff --git a/ui/src/utils/jobs.ts b/ui/src/utils/jobs.ts index 3a74870e5..ed656edce 100644 --- a/ui/src/utils/jobs.ts +++ b/ui/src/utils/jobs.ts @@ -50,6 +50,22 @@ export const deleteJob = (jobID: string) => { }); }; +export const markJobAsStopped = (jobID: string) => { + return new Promise((resolve, reject) => { + apiClient + .get(`/api/jobs/${jobID}/mark_stopped`) + .then(res => res.data) + .then(data => { + console.log('Job marked as stopped:', data); + resolve(); + }) + .catch(error => { + console.error('Error marking job as stopped:', error); + reject(error); + }); + }); +}; + export const getJobConfig = (job: Job) => { return JSON.parse(job.job_config) as JobConfig; 
};
diff --git a/version.py b/version.py
index d613b95e0..af9f7a48b 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.5.5"
\ No newline at end of file
+VERSION = "0.6.0"
\ No newline at end of file
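
A note on the keyboard navigation introduced in `SampleImageViewer` above: the component assumes sample files are named `<name>_<step>_<promptIdx>.<ext>` and that `sampleImages` is ordered row-major, one row of `numSamples` images per sampling step and one column per prompt. The sketch below restates that index arithmetic as standalone helpers. It is illustrative only, not part of the patch, and the helper names (`parseSampleFilename`, `nextSampleIndex`) are made up for this example.

```ts
// Illustrative helpers only; not part of the patch above.
export type ArrowKey = 'ArrowUp' | 'ArrowDown' | 'ArrowLeft' | 'ArrowRight';

// Parse "<name>_<step>_<promptIdx>.<ext>", mirroring the imgInfo memo in SampleImageViewer.
export function parseSampleFilename(path: string): { step: number; promptIdx: number } | null {
  const filename = path.split('/').pop();
  if (!filename) return null;
  const parts = filename
    .split('.')[0]
    .split('_')
    .filter(p => p !== '');
  if (parts.length !== 3) return null;
  return { step: parseInt(parts[1], 10), promptIdx: parseInt(parts[2], 10) };
}

// Arrow-key navigation over a row-major grid of samples, where
// index = stepRow * numSamples + promptIdx. Returns null when the move would leave the grid.
export function nextSampleIndex(
  current: number,
  promptIdx: number,
  numSamples: number,
  total: number,
  key: ArrowKey,
): number | null {
  const rowStart = current - promptIdx; // first image generated at the current step
  const rowEnd = rowStart + numSamples - 1; // last image generated at the current step
  let next: number;
  switch (key) {
    case 'ArrowUp':
      next = current - numSamples; // same prompt, previous sampling step
      break;
    case 'ArrowDown':
      next = current + numSamples; // same prompt, next sampling step
      break;
    case 'ArrowLeft':
      next = current - 1; // previous prompt at the same step
      if (next < rowStart) return null;
      break;
    case 'ArrowRight':
      next = current + 1; // next prompt at the same step
      if (next > rowEnd) return null;
      break;
    default:
      return null;
  }
  return next >= 0 && next < total ? next : null;
}
```

Up/Down therefore jump by a whole row (same prompt, different step), while Left/Right stay clamped to the current step's row, which matches the bounds checks in `handleArrowUp` through `handleArrowRight`.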