43 commits
5e5e9db
Fix: WAN 2.2 I2V boundary detection, AdamW8bit OOM crash, and add gra…
Oct 22, 2025
12e2b37
Improve video training with better bucket allocation
Oct 28, 2025
a1f70bc
Fix MoE training: per-expert LR logging and param group splitting
Oct 28, 2025
a2749c5
Add progressive alpha scheduling and comprehensive metrics tracking f…
Oct 29, 2025
86d107e
Merge remote-tracking branch 'upstream/main'
Oct 29, 2025
c91628e
Update README with comprehensive fork documentation and alpha schedul…
Oct 29, 2025
61143d6
Add comprehensive beginner-friendly documentation and UI improvements
Oct 29, 2025
96b1bda
Remove sponsors section from README - this is a fork without sponsors
Oct 29, 2025
bce9866
Fix confusing expert metrics display - add current training status
Oct 29, 2025
bd45a9e
Fix UnboundLocalError: remove redundant local 'import os'
Oct 29, 2025
abbe765
Add metrics API endpoint and UI components for real-time training mon…
Oct 29, 2025
edaf27d
Fix: Always show Loss Trend Analysis section with collection progress
Oct 29, 2025
a551b65
Fix: SVG charts now display correctly - add viewBox for proper coordi…
Oct 29, 2025
1682199
Fix: Downsample metrics to 500 points and lower phase transition thre…
Oct 30, 2025
885bbd4
Add comprehensive training recommendations based on research
Oct 30, 2025
705c5d3
Fix TRAINING_RECOMMENDATIONS for motion training
Oct 30, 2025
54c059a
Fix metrics to use EMA instead of simple averages
Oct 30, 2025
20b3c12
FIX CRITICAL BUG: Training loop re-doing checkpoint step on resume
Oct 30, 2025
226d19d
Remove useless checkpoint analyzer script
Oct 30, 2025
66978dd
Fix: Export EMA metrics to JSONL for UI visualization
Oct 30, 2025
fa12a08
Fix: Optimizer state loading counting wrong number of params for MoE
Oct 30, 2025
264c162
Fix: Set current_expert_name for metrics tracking
Oct 30, 2025
aecc467
Fix alpha scheduler not loading for MoE models on resume
Oct 31, 2025
b1ea60f
feat: Add SageAttention support for Wan models
Nov 4, 2025
20d689d
Fix CRITICAL metrics regression: boundary misalignment on resume + ad…
Nov 4, 2025
8b8506c
Merge feature/sageattention-wan-support into main
Nov 4, 2025
6a7ecac
docs: Update README with SageAttention and metrics fixes
Nov 4, 2025
850db0f
docs: Update installation instructions to use PyTorch nightly
Nov 4, 2025
26e9bdb
docs: Major README overhaul - Focus on Wan 2.2 I2V optimization
Nov 4, 2025
88785a9
docs: Fix Blackwell CUDA requirements - CUDA 13.0 not 12.8
Nov 4, 2025
0cacab8
Fix: torchao quantized tensors don't support copy argument in .to()
Nov 4, 2025
3ad8bfb
Fix critical FP16 hardcoding causing low-noise training instability
Nov 4, 2025
8589967
Fix metrics UI cross-contamination in per-expert windows
Nov 4, 2025
47dff0d
Fix FP16 hardcoding in TrainSliderProcess mask processing
Nov 4, 2025
eeeeb2e
Fix LR scheduler stepping to respect gradient accumulation
Nov 4, 2025
f026f35
CRITICAL: Fix VAE dtype mismatch in Wan encode_images
Nov 5, 2025
c7c3459
CRITICAL: Revert CFG-zero to be optional (match Ostris Nov 4 update)
Nov 5, 2025
728b46d
CRITICAL: Fix multiple SageAttention bugs causing training instability
Nov 5, 2025
7c9b205
Additional SageAttention and VAE dtype refinements
Nov 5, 2025
1d9dc98
Fix rotary embedding application to match Diffusers WAN reference
Nov 5, 2025
67445b9
Add temporal_jitter parameter for video frame sampling
Nov 5, 2025
ab59f00
Document temporal_jitter feature in README
Nov 5, 2025
80ff3db
Fix VAE dtype handling for WAN 2.2 I2V training to prevent blurry sam…
Nov 5, 2025
Fix MoE training: per-expert LR logging and param group splitting
This commit fixes two critical issues with Mixture of Experts (MoE) training
for dual-transformer models like WAN 2.2 14B I2V:

**Issue 1: Averaged LR logging masked expert-specific behavior**
- Previous logging averaged LR across all param groups (both experts)
- Made it impossible to verify LR was resuming correctly per expert
- Example: High Noise at 0.0005, Low Noise at 0.00001 → logged as 0.00026
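A minimal sketch of the masking effect, using the values from the example above (the snippet is illustrative, not code from the repo):

```
high_noise_lr = 0.0005   # High Noise expert's current LR
low_noise_lr = 0.00001   # Low Noise expert's current LR

# Old behaviour: a single averaged value is logged for both param groups.
avg_lr = (high_noise_lr + low_noise_lr) / 2   # 0.000255 -- matches neither expert
print(f"lr: {avg_lr:.1e}")

# New behaviour: one value per expert, so per-expert adaptation stays visible.
print(" ".join(f"lr{i}: {lr:.1e}" for i, lr in enumerate([high_noise_lr, low_noise_lr])))
```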

**Fix:** Per-expert LR display (BaseSDTrainProcess.py lines 2198-2226)
- Detects MoE via multiple param groups
- Shows separate LR for each expert: "lr0: 5.0e-04 lr1: 3.5e-05"
- Makes expert-specific LR adaptation visible and debuggable
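Condensed, the added logic behaves roughly like the helper below (a sketch: `format_lr_string` is a hypothetical name, and the optimizer is assumed to expose `get_learning_rates()` / `get_avg_learning_rate()` as the automagic optimizer does):

```
def format_lr_string(optimizer):
    """Build the progress-bar LR string: per-expert for MoE, single value otherwise."""
    if hasattr(optimizer, 'get_learning_rates') and len(optimizer.param_groups) > 1:
        # MoE: one param group per expert, so report each group's LR separately
        group_lrs = optimizer.get_learning_rates()
        return " ".join(
            f"lr{i}: {(lr.item() if hasattr(lr, 'item') else lr):.1e}"
            for i, lr in enumerate(group_lrs)
        )
    # Single param group: keep the original averaged display
    return f"lr: {optimizer.get_avg_learning_rate():.1e}"
```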

**Issue 2: Transformer detection bug prevented param group splitting**
- _prepare_moe_optimizer_params() checked for '.transformer_1.' (dots)
- But lora_name uses '$$' separator: "transformer$$transformer_1$$blocks..."
- Check never matched, all params went into single group → no per-expert LRs
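A quick demonstration of why the old check could never match (the `lora_name` is the one quoted above):

```
lora_name = "transformer$$transformer_1$$blocks$$0$$attn1$$to_q"

# Old check: expects dot-delimited module paths, but lora_name uses '$$' separators.
print('.transformer_1.' in lora_name)   # False -> all params ended up in one group

# New check: plain substring match works regardless of the separator.
print('transformer_1' in lora_name)     # True  -> routed to the High Noise group
```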

**Fix:** Corrected substring matching (lora_special.py lines 622-630)
- Changed from '.transformer_1.' to 'transformer_1' substring check
- Now correctly creates separate param groups for transformer_1/transformer_2
- Enables per-expert lr_bump, min_lr, max_lr with automagic optimizer
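As an illustration, per-expert overrides can be passed through `optimizer_params`; the key names below are the ones read in `_prepare_moe_optimizer_params`, while the values are made up for the example:

```
optimizer_params = {
    # Defaults, used when no per-expert override is present
    "lr_bump": 1e-6,
    "min_lr": 1e-7,
    "max_lr": 1e-3,
    # High Noise expert (transformer_1) overrides
    "high_noise_lr_bump": 2e-6,
    "high_noise_max_lr": 5e-4,
    # Low Noise expert (transformer_2) overrides
    "low_noise_min_lr": 1e-8,
}
```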

**Result:**
- Visible per-expert LR adaptation: lr0 and lr1 tracked independently
- Proper LR state preservation when experts switch every N steps
- Accurate monitoring of training progress for each expert

Example output:
```
lr0: 2.8e-05 lr1: 0.0e+00 loss: 8.414e-02  # High Noise active
lr0: 5.2e-05 lr1: 1.0e-05 loss: 7.821e-02  # After switch to Low Noise
lr0: 5.2e-05 lr1: 3.4e-05 loss: 6.103e-02  # Low Noise adapting, High preserved
```

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
AI Toolkit Contributor and claude committed Oct 29, 2025
commit a1f70bc513582c3c80a9cbce17402060b0baefcc
21 changes: 19 additions & 2 deletions jobs/process/BaseSDTrainProcess.py
@@ -1791,6 +1791,8 @@ def run(self):
config['default_lr'] = self.train_config.lr
if 'learning_rate' in sig.parameters:
    config['learning_rate'] = self.train_config.lr
+if 'optimizer_params' in sig.parameters:
+    config['optimizer_params'] = self.train_config.optimizer_params
params_net = self.network.prepare_optimizer_params(
    **config
)
@@ -2203,7 +2205,13 @@ def run(self):
# torch.cuda.empty_cache()
# if optimizer has get_lrs method, then use it
if hasattr(optimizer, 'get_avg_learning_rate'):
-    learning_rate = optimizer.get_avg_learning_rate()
+    # Check if this is MoE with multiple param groups
+    if hasattr(optimizer, 'get_learning_rates') and len(optimizer.param_groups) > 1:
+        # Show per-expert LRs for MoE
+        group_lrs = optimizer.get_learning_rates()
+        learning_rate = None  # Will use group_lrs instead
+    else:
+        learning_rate = optimizer.get_avg_learning_rate()
elif hasattr(optimizer, 'get_learning_rates'):
    learning_rate = optimizer.get_learning_rates()[0]
elif self.train_config.optimizer.lower().startswith('dadaptation') or \
@@ -2215,7 +2223,16 @@
else:
    learning_rate = optimizer.param_groups[0]['lr']

-prog_bar_string = f"lr: {learning_rate:.1e}"
+# Format LR string (per-expert for MoE, single value otherwise)
+if hasattr(optimizer, 'get_avg_learning_rate') and learning_rate is None:
+    # MoE: show each expert's LR
+    lr_strings = []
+    for i, lr in enumerate(group_lrs):
+        lr_val = lr.item() if hasattr(lr, 'item') else lr
+        lr_strings.append(f"lr{i}: {lr_val:.1e}")
+    prog_bar_string = " ".join(lr_strings)
+else:
+    prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
    prog_bar_string += f" {key}: {value:.3e}"

104 changes: 102 additions & 2 deletions toolkit/lora_special.py
@@ -570,10 +570,110 @@ def create_modules(
    unet.conv_in = self.unet_conv_in
    unet.conv_out = self.unet_conv_out

-def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
-    # call Lora prepare_optimizer_params
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, optimizer_params=None):
    # Check if we're training a WAN 2.2 14B MoE model
    base_model = self.base_model_ref() if self.base_model_ref is not None else None
    is_wan22_moe = base_model is not None and hasattr(base_model, 'arch') and base_model.arch in ["wan22_14b", "wan22_14b_i2v"]

    # If MoE model and optimizer_params provided, split param groups for high/low noise experts
    if is_wan22_moe and optimizer_params is not None and self.unet_loras:
        return self._prepare_moe_optimizer_params(text_encoder_lr, unet_lr, default_lr, optimizer_params)

    # Otherwise use standard param group creation
    all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)

    if self.full_train_in_out:
        if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):
            all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
            all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
        else:
            all_params.append({"lr": unet_lr, "params": list(self.unet_conv_in.parameters())})
            all_params.append({"lr": unet_lr, "params": list(self.unet_conv_out.parameters())})

    return all_params

def _prepare_moe_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, optimizer_params):
    """
    Prepare optimizer params with separate groups for High Noise and Low Noise experts.
    Allows per-expert lr_bump, min_lr, max_lr configuration for automagic optimizer.
    """
    self.requires_grad_(True)
    all_params = []

    def enumerate_params(loras):
        params = []
        for lora in loras:
            params.extend(lora.parameters())
        return params

    # Handle text encoder loras (standard, no splitting)
    if self.text_encoder_loras:
        param_data = {"params": enumerate_params(self.text_encoder_loras)}
        if text_encoder_lr is not None:
            param_data["lr"] = text_encoder_lr
        all_params.append(param_data)

    # Split unet_loras by transformer (High Noise = transformer_1, Low Noise = transformer_2)
    if self.unet_loras:
        high_noise_loras = []
        low_noise_loras = []
        other_loras = []

        for lora in self.unet_loras:
            # Note: lora_name uses $$ as separator, so check for 'transformer_1' substring
            # This correctly matches names like "transformer$$transformer_1$$blocks$$0$$attn1$$to_q"
            if 'transformer_1' in lora.lora_name:
                high_noise_loras.append(lora)
            elif 'transformer_2' in lora.lora_name:
                low_noise_loras.append(lora)
            else:
                other_loras.append(lora)

        # Extract per-expert optimizer params with fallback to defaults
        default_lr_bump = optimizer_params.get('lr_bump')
        default_min_lr = optimizer_params.get('min_lr')
        default_max_lr = optimizer_params.get('max_lr')

        # High Noise Expert param group
        if high_noise_loras:
            high_noise_params = {"params": enumerate_params(high_noise_loras)}
            if unet_lr is not None:
                high_noise_params["lr"] = unet_lr

            # Add per-expert optimizer params if using automagic
            if default_lr_bump is not None:
                high_noise_params["lr_bump"] = optimizer_params.get('high_noise_lr_bump', default_lr_bump)
            if default_min_lr is not None:
                high_noise_params["min_lr"] = optimizer_params.get('high_noise_min_lr', default_min_lr)
            if default_max_lr is not None:
                high_noise_params["max_lr"] = optimizer_params.get('high_noise_max_lr', default_max_lr)

            all_params.append(high_noise_params)

        # Low Noise Expert param group
        if low_noise_loras:
            low_noise_params = {"params": enumerate_params(low_noise_loras)}
            if unet_lr is not None:
                low_noise_params["lr"] = unet_lr

            # Add per-expert optimizer params if using automagic
            if default_lr_bump is not None:
                low_noise_params["lr_bump"] = optimizer_params.get('low_noise_lr_bump', default_lr_bump)
            if default_min_lr is not None:
                low_noise_params["min_lr"] = optimizer_params.get('low_noise_min_lr', default_min_lr)
            if default_max_lr is not None:
                low_noise_params["max_lr"] = optimizer_params.get('low_noise_max_lr', default_max_lr)

            all_params.append(low_noise_params)

        # Other loras (not transformer-specific) - use defaults
        if other_loras:
            other_params = {"params": enumerate_params(other_loras)}
            if unet_lr is not None:
                other_params["lr"] = unet_lr
            all_params.append(other_params)

    # Add full_train_in_out params if needed
    if self.full_train_in_out:
        base_model = self.base_model_ref() if self.base_model_ref is not None else None
        if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):