diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
index b3f4af3ff..450d6fe1a 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
@@ -188,9 +188,15 @@ def __init__(
 
         self._wan_cache = None
 
         self.is_multistage = True
+
+        # Detect if this is I2V or T2V model
+        self.is_i2v = 'i2v' in model_config.name_or_path.lower()
+        self.boundary_ratio = boundary_ratio_i2v if self.is_i2v else boundary_ratio_t2v
+
         # multistage boundaries split the models up when sampling timesteps
-        # for wan 2.2 14b. the timesteps are 1000-875 for transformer 1 and 875-0 for transformer 2
-        self.multistage_boundaries: List[float] = [0.875, 0.0]
+        # for wan 2.2 14b I2V: timesteps 1000-900 for transformer 1 and 900-0 for transformer 2
+        # for wan 2.2 14b T2V: timesteps 1000-875 for transformer 1 and 875-0 for transformer 2
+        self.multistage_boundaries: List[float] = [self.boundary_ratio, 0.0]
         self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True)
         self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True)
@@ -339,7 +345,7 @@ def load_wan_transformer(self, transformer_path, subfolder=None):
             transformer_2=transformer_2,
             torch_dtype=self.torch_dtype,
             device=self.device_torch,
-            boundary_ratio=boundary_ratio_t2v,
+            boundary_ratio=self.boundary_ratio,
             low_vram=self.model_config.low_vram,
         )
 
@@ -363,8 +369,7 @@ def get_generation_pipeline(self):
             expand_timesteps=self._wan_expand_timesteps,
             device=self.device_torch,
             aggressive_offload=self.model_config.low_vram,
-            # todo detect if it is i2v or t2v
-            boundary_ratio=boundary_ratio_t2v,
+            boundary_ratio=self.boundary_ratio,
         )
 
         # pipeline = pipeline.to(self.device_torch)
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 1787f0da9..14efbec92 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -111,6 +111,35 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
 
     def before_model_load(self):
         pass
+
+    def _calculate_grad_norm(self, params):
+        if params is None or len(params) == 0:
+            return None
+
+        if isinstance(params[0], dict):
+            param_iterable = (p for group in params for p in group.get('params', []))
+        else:
+            param_iterable = params
+
+        total_norm_sq = None
+        for param in param_iterable:
+            if param is None:
+                continue
+            grad = getattr(param, 'grad', None)
+            if grad is None:
+                continue
+            if grad.is_sparse:
+                grad = grad.coalesce()._values()
+            grad_norm = grad.detach().float().norm(2)
+            if total_norm_sq is None:
+                total_norm_sq = grad_norm.pow(2)
+            else:
+                total_norm_sq = total_norm_sq + grad_norm.pow(2)
+
+        if total_norm_sq is None:
+            return None
+
+        return total_norm_sq.sqrt()
 
     def cache_sample_prompts(self):
         if self.train_config.disable_sampling:
@@ -2030,7 +2059,11 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD
                     torch.cuda.empty_cache()
 
+                grad_norm_value = None
                 if not self.is_grad_accumulation_step:
+                    grad_norm_tensor = self._calculate_grad_norm(self.params)
+                    if grad_norm_tensor is not None:
+                        grad_norm_value = grad_norm_tensor.item()
                     # fix this for multi params
                     if self.train_config.optimizer != 'adafactor':
                         if isinstance(self.params[0], dict):
@@ -2068,6 +2101,8 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD
         loss_dict = OrderedDict(
             {'loss': (total_loss / len(batch_list)).item()}
         )
+        if grad_norm_value is not None:
+            loss_dict['grad_norm'] = grad_norm_value
 
         self.end_of_training_loop()
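The `grad_norm` value that now lands in `loss_dict` is the global L2 norm of all gradients about to be applied: the square root of the summed squared per-parameter gradient norms, taken just before the optimizer step on non-accumulation steps. A minimal standalone sketch of the same reduction over plain PyTorch param groups (names here are illustrative, not the trainer's API):

    import torch

    def global_grad_norm(param_groups):
        # sqrt(sum ||grad||^2) over every parameter that currently has a gradient
        total_sq = 0.0
        for group in param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    total_sq += p.grad.detach().float().norm(2).item() ** 2
        return total_sq ** 0.5

    # after loss.backward(): log {"grad_norm": global_grad_norm(optimizer.param_groups)}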
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index d376b2991..4b74e3214 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1790,6 +1790,8 @@ def run(self):
                     config['default_lr'] = self.train_config.lr
                 if 'learning_rate' in sig.parameters:
                     config['learning_rate'] = self.train_config.lr
+                if 'optimizer_params' in sig.parameters:
+                    config['optimizer_params'] = self.train_config.optimizer_params
                 params_net = self.network.prepare_optimizer_params(
                     **config
                 )
@@ -2176,7 +2178,7 @@ def run(self):
                 if self.torch_profiler is not None:
                     torch.cuda.synchronize()  # Make sure all CUDA ops are done
                     self.torch_profiler.stop()
-                    
+
                     print("\n==== Profile Results ====")
                     print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
                 self.timer.stop('train_loop')
@@ -2188,28 +2190,45 @@ def run(self):
 
                 if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
                     self.adapter.clear_memory()
 
-                with torch.no_grad():
-                    # torch.cuda.empty_cache()
-                    # if optimizer has get_lrs method, then use it
-                    if hasattr(optimizer, 'get_avg_learning_rate'):
-                        learning_rate = optimizer.get_avg_learning_rate()
-                    elif hasattr(optimizer, 'get_learning_rates'):
-                        learning_rate = optimizer.get_learning_rates()[0]
-                    elif self.train_config.optimizer.lower().startswith('dadaptation') or \
-                            self.train_config.optimizer.lower().startswith('prodigy'):
-                        learning_rate = (
-                            optimizer.param_groups[0]["d"] *
-                            optimizer.param_groups[0]["lr"]
-                        )
-                    else:
-                        learning_rate = optimizer.param_groups[0]['lr']
-
-                    prog_bar_string = f"lr: {learning_rate:.1e}"
-                    for key, value in loss_dict.items():
-                        prog_bar_string += f" {key}: {value:.3e}"
+                # Only update progress bar if we didn't OOM (loss_dict exists)
+                if not did_oom:
+                    with torch.no_grad():
+                        # torch.cuda.empty_cache()
+                        # if optimizer has get_lrs method, then use it
+                        if hasattr(optimizer, 'get_avg_learning_rate'):
+                            # Check if this is MoE with multiple param groups
+                            if hasattr(optimizer, 'get_learning_rates') and len(optimizer.param_groups) > 1:
+                                # Show per-expert LRs for MoE
+                                group_lrs = optimizer.get_learning_rates()
+                                learning_rate = None  # Will use group_lrs instead
+                            else:
+                                learning_rate = optimizer.get_avg_learning_rate()
+                        elif hasattr(optimizer, 'get_learning_rates'):
+                            learning_rate = optimizer.get_learning_rates()[0]
+                        elif self.train_config.optimizer.lower().startswith('dadaptation') or \
+                                self.train_config.optimizer.lower().startswith('prodigy'):
+                            learning_rate = (
+                                optimizer.param_groups[0]["d"] *
+                                optimizer.param_groups[0]["lr"]
+                            )
+                        else:
+                            learning_rate = optimizer.param_groups[0]['lr']
+
+                        # Format LR string (per-expert for MoE, single value otherwise)
+                        if hasattr(optimizer, 'get_avg_learning_rate') and learning_rate is None:
+                            # MoE: show each expert's LR
+                            lr_strings = []
+                            for i, lr in enumerate(group_lrs):
+                                lr_val = lr.item() if hasattr(lr, 'item') else lr
+                                lr_strings.append(f"lr{i}: {lr_val:.1e}")
+                            prog_bar_string = " ".join(lr_strings)
+                        else:
+                            prog_bar_string = f"lr: {learning_rate:.1e}"
+                        for key, value in loss_dict.items():
+                            prog_bar_string += f" {key}: {value:.3e}"
 
-                    if self.progress_bar is not None:
-                        self.progress_bar.set_postfix_str(prog_bar_string)
+                        if self.progress_bar is not None:
+                            self.progress_bar.set_postfix_str(prog_bar_string)
                 # if the batch is a DataLoaderBatchDTO, then we need to clean it up
                 if isinstance(batch, DataLoaderBatchDTO):
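The MoE branch above assumes the custom optimizer exposes both `get_avg_learning_rate` and `get_learning_rates`, and that the latter returns one learning rate per registered param group. A tiny sketch of the resulting progress-bar postfix with stand-in values rather than a real optimizer:

    group_lrs = [2.1e-04, 5.0e-05]  # e.g. one LR per expert param group
    prog_bar_string = " ".join(f"lr{i}: {lr:.1e}" for i, lr in enumerate(group_lrs))
    print(prog_bar_string)  # lr0: 2.1e-04 lr1: 5.0e-05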
diff --git a/toolkit/buckets.py b/toolkit/buckets.py
index 3b0cbf19d..8c3de7d78 100644
--- a/toolkit/buckets.py
+++ b/toolkit/buckets.py
@@ -6,7 +6,31 @@ class BucketResolution(TypedDict):
     height: int
 
 
-# resolutions SDXL was trained on with a 1024x1024 base resolution
+# Video-friendly resolutions with common aspect ratios
+# Base resolution: 1024×1024
+# Keep only PRIMARY buckets to avoid videos being assigned to undersized buckets
+resolutions_video_1024: List[BucketResolution] = [
+    # Square
+    {"width": 1024, "height": 1024},  # 1:1
+
+    # 16:9 landscape (1.778 aspect - YouTube, TV standard)
+    {"width": 1024, "height": 576},
+
+    # 9:16 portrait (0.562 aspect - TikTok, Instagram Reels)
+    {"width": 576, "height": 1024},
+
+    # 4:3 landscape (1.333 aspect - older content)
+    {"width": 1024, "height": 768},
+
+    # 3:4 portrait (0.75 aspect)
+    {"width": 768, "height": 1024},
+
+    # Slightly wider/taller variants for flexibility
+    {"width": 1024, "height": 640},  # 1.6 aspect
+    {"width": 640, "height": 1024},  # 0.625 aspect
+]
+
+# SDXL resolutions (kept for backwards compatibility)
 resolutions_1024: List[BucketResolution] = [
     # SDXL Base resolution
     {"width": 1024, "height": 1024},
@@ -56,12 +80,48 @@ class BucketResolution(TypedDict):
     {"width": 128, "height": 8192},
 ]
 
 
-def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
-    # determine scaler form 1024 to resolution
-    scaler = resolution / 1024
+def get_bucket_sizes(resolution: int = 512, divisibility: int = 8, use_video_buckets: bool = True, max_pixels_per_frame: int = None) -> List[BucketResolution]:
+    # Use video-friendly buckets by default for better aspect ratio preservation
+    base_resolutions = resolutions_video_1024 if use_video_buckets else resolutions_1024
+
+    # If max_pixels_per_frame is specified, use pixel budget scaling
+    # This maximizes resolution for each aspect ratio while keeping memory usage consistent
+    if max_pixels_per_frame is not None:
+        bucket_size_list = []
+        for bucket in base_resolutions:
+            # Calculate aspect ratio
+            base_aspect = bucket["width"] / bucket["height"]
+
+            # Calculate optimal dimensions for this aspect ratio within pixel budget
+            # For aspect ratio a = w/h and pixel budget p = w*h:
+            # w = sqrt(p * a), h = sqrt(p / a)
+            optimal_width = (max_pixels_per_frame * base_aspect) ** 0.5
+            optimal_height = (max_pixels_per_frame / base_aspect) ** 0.5
+
+            # Round down to divisibility
+            width = int(optimal_width)
+            height = int(optimal_height)
+            width = width - (width % divisibility)
+            height = height - (height % divisibility)
+
+            # Verify we're under budget (should always be true with round-down)
+            actual_pixels = width * height
+            if actual_pixels > max_pixels_per_frame:
+                # Safety check - scale down if somehow over budget
+                scale = (max_pixels_per_frame / actual_pixels) ** 0.5
+                width = int(width * scale)
+                height = int(height * scale)
+                width = width - (width % divisibility)
+                height = height - (height % divisibility)
+
+            bucket_size_list.append({"width": width, "height": height})
+        return bucket_size_list
+
+    # Original scaling logic (for backwards compatibility)
+    scaler = resolution / 1024
     bucket_size_list = []
-    for bucket in resolutions_1024:
+    for bucket in base_resolutions:
         # must be divisible by 8
         width = int(bucket["width"] * scaler)
         height = int(bucket["height"] * scaler)
@@ -69,6 +129,12 @@ def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[Bucke
             width = width - (width % divisibility)
         if height % divisibility != 0:
             height = height - (height % divisibility)
+
+        # Filter buckets where any dimension exceeds the resolution parameter
+        # This ensures memory usage stays within bounds for the target resolution
+        if max(width, height) > resolution:
+            continue
+
         bucket_size_list.append({"width": width, "height": height})
 
     return bucket_size_list
@@ -86,17 +152,15 @@ def get_bucket_for_image_size(
     height: int,
     bucket_size_list: List[BucketResolution] = None,
     resolution: Union[int, None] = None,
-    divisibility: int = 8
+    divisibility: int = 8,
+    max_pixels_per_frame: int = None
 ) -> BucketResolution:
     if bucket_size_list is None and resolution is None:
         # get resolution from width and height
         resolution = get_resolution(width, height)
 
     if bucket_size_list is None:
-        # if real resolution is smaller, use that instead
-        real_resolution = get_resolution(width, height)
-        resolution = min(resolution, real_resolution)
-        bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility)
+        bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility, max_pixels_per_frame=max_pixels_per_frame)
 
     # Check for exact match first
     for bucket in bucket_size_list:
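The pixel-budget branch above keeps each bucket's aspect ratio and solves w*h <= max_pixels_per_frame in closed form: for aspect a = w/h and budget p, w = sqrt(p*a) and h = sqrt(p/a), both rounded down to the divisibility. A quick standalone check of that arithmetic with an assumed budget of 512*1024 pixels:

    import math

    def fit_to_budget(aspect: float, max_pixels: int, divisibility: int = 16):
        # same math as the max_pixels_per_frame path in get_bucket_sizes above
        w = int(math.sqrt(max_pixels * aspect))
        h = int(math.sqrt(max_pixels / aspect))
        return w - (w % divisibility), h - (h % divisibility)

    budget = 512 * 1024  # assumed per-frame pixel budget (~0.52 MP)
    print(fit_to_budget(2.0, budget))  # 2:1 -> (1024, 512)
    print(fit_to_budget(1.0, budget))  # 1:1 -> (720, 720)

Every bucket produced this way has roughly the same pixel count, so per-frame memory stays nearly constant across aspect ratios, whereas the older resolution-scaled path multiplies every base bucket by the same factor and lets the area differ between aspect ratios.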
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 1fa7c688c..fff21e75e 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -836,6 +836,7 @@ def __init__(self, **kwargs):
         self.random_scale: bool = kwargs.get('random_scale', False)
         self.random_crop: bool = kwargs.get('random_crop', False)
         self.resolution: int = kwargs.get('resolution', 512)
+        self.max_pixels_per_frame: int = kwargs.get('max_pixels_per_frame', None)
         self.scale: float = kwargs.get('scale', 1.0)
         self.buckets: bool = kwargs.get('buckets', True)
         self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 95075a61a..f8ac8954f 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -332,6 +332,7 @@ def __getitem__(self, index):
             width=img2.width,
             height=img2.height,
             resolution=self.size,
+            max_pixels_per_frame=getattr(self.dataset_config, 'max_pixels_per_frame', None) if hasattr(self, 'dataset_config') else None
             # divisibility=self.
         )
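A hedged end-to-end sketch of how the new dataset option is meant to flow; the module paths follow this diff, but the class name, default values, and numbers are assumptions and it would only run inside the toolkit repo:

    from toolkit.config_modules import DatasetConfig
    from toolkit.buckets import get_bucket_for_image_size

    cfg = DatasetConfig(resolution=1024, max_pixels_per_frame=1024 * 576)
    bucket = get_bucket_for_image_size(
        width=1920, height=1080,                # source video frame size (made up)
        resolution=cfg.resolution,
        divisibility=cfg.bucket_tolerance,      # 64 by default per config_modules.py
        max_pixels_per_frame=cfg.max_pixels_per_frame,
    )
    # Expect a 16:9-ish bucket whose width * height stays at or under the pixel budget.
    print(bucket)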
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 3fe105924..288940e07 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -209,6 +209,7 @@ def setup_buckets(self: 'AiToolkitDataset', quiet=False):
         config: 'DatasetConfig' = self.dataset_config
         resolution = config.resolution
         bucket_tolerance = config.bucket_tolerance
+        max_pixels_per_frame = config.max_pixels_per_frame
         file_list: List['FileItemDTO'] = self.file_list
 
         # for file_item in enumerate(file_list):
@@ -240,7 +241,8 @@ def setup_buckets(self: 'AiToolkitDataset', quiet=False):
             bucket_resolution = get_bucket_for_image_size(
                 width, height,
                 resolution=resolution,
-                divisibility=bucket_tolerance
+                divisibility=bucket_tolerance,
+                max_pixels_per_frame=max_pixels_per_frame
             )
 
             # Calculate scale factors for width and height
@@ -1601,7 +1603,8 @@ def setup_poi_bucket(self: 'FileItemDTO'):
         bucket_resolution = get_bucket_for_image_size(
             new_width, new_height,
             resolution=self.dataset_config.resolution,
-            divisibility=bucket_tolerance
+            divisibility=bucket_tolerance,
+            max_pixels_per_frame=self.dataset_config.max_pixels_per_frame
         )
 
         width_scale_factor = bucket_resolution["width"] / new_width
+ """ + self.requires_grad_(True) + all_params = [] + + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + # Handle text encoder loras (standard, no splitting) + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + + # Split unet_loras by transformer (High Noise = transformer_1, Low Noise = transformer_2) + if self.unet_loras: + high_noise_loras = [] + low_noise_loras = [] + other_loras = [] + + for lora in self.unet_loras: + # Note: lora_name uses $$ as separator, so check for 'transformer_1' substring + # This correctly matches names like "transformer$$transformer_1$$blocks$$0$$attn1$$to_q" + if 'transformer_1' in lora.lora_name: + high_noise_loras.append(lora) + elif 'transformer_2' in lora.lora_name: + low_noise_loras.append(lora) + else: + other_loras.append(lora) + + # Extract per-expert optimizer params with fallback to defaults + default_lr_bump = optimizer_params.get('lr_bump') + default_min_lr = optimizer_params.get('min_lr') + default_max_lr = optimizer_params.get('max_lr') + + # High Noise Expert param group + if high_noise_loras: + high_noise_params = {"params": enumerate_params(high_noise_loras)} + if unet_lr is not None: + high_noise_params["lr"] = unet_lr + + # Add per-expert optimizer params if using automagic + if default_lr_bump is not None: + high_noise_params["lr_bump"] = optimizer_params.get('high_noise_lr_bump', default_lr_bump) + if default_min_lr is not None: + high_noise_params["min_lr"] = optimizer_params.get('high_noise_min_lr', default_min_lr) + if default_max_lr is not None: + high_noise_params["max_lr"] = optimizer_params.get('high_noise_max_lr', default_max_lr) + + all_params.append(high_noise_params) + + # Low Noise Expert param group + if low_noise_loras: + low_noise_params = {"params": enumerate_params(low_noise_loras)} + if unet_lr is not None: + low_noise_params["lr"] = unet_lr + + # Add per-expert optimizer params if using automagic + if default_lr_bump is not None: + low_noise_params["lr_bump"] = optimizer_params.get('low_noise_lr_bump', default_lr_bump) + if default_min_lr is not None: + low_noise_params["min_lr"] = optimizer_params.get('low_noise_min_lr', default_min_lr) + if default_max_lr is not None: + low_noise_params["max_lr"] = optimizer_params.get('low_noise_max_lr', default_max_lr) + + all_params.append(low_noise_params) + + # Other loras (not transformer-specific) - use defaults + if other_loras: + other_params = {"params": enumerate_params(other_loras)} + if unet_lr is not None: + other_params["lr"] = unet_lr + all_params.append(other_params) + + # Add full_train_in_out params if needed if self.full_train_in_out: base_model = self.base_model_ref() if self.base_model_ref is not None else None if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):