Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix: WAN 2.2 I2V boundary detection, AdamW8bit OOM crash, and add gradient norm logging

This commit includes three critical fixes and one feature addition:

1. WAN 2.2 I2V Boundary Detection Fix:
   - Auto-detect I2V vs T2V models from model path
   - Use correct boundary ratio (0.9 for I2V, 0.875 for T2V)
   - Previous hardcoded T2V boundary caused training issues for I2V models
   - Fixes timestep distribution for dual LoRA (HIGH/LOW noise) training

2. AdamW8bit OOM Loss Access Fix:
   - Prevent crash when accessing loss_dict after OOM event
   - Only update progress bar if training step succeeded (not did_oom)
   - Resolves KeyError when loss_dict is not populated due to OOM

3. Gradient Norm Logging:
   - Add _calculate_grad_norm() method for comprehensive gradient tracking
   - Handles sparse gradients and param groups correctly
   - Logs grad_norm in loss_dict for monitoring training stability
   - Essential for diagnosing divergence and LR issues

These fixes improve training stability and monitoring for WAN 2.2 I2V/T2V models.
  • Loading branch information
AI Toolkit Contributor committed Oct 22, 2025
commit cf67c367201864b2db465b7e2d828ab3cb7d5a68
15 changes: 10 additions & 5 deletions extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,15 @@ def __init__(
self._wan_cache = None

self.is_multistage = True

# Detect if this is I2V or T2V model
self.is_i2v = 'i2v' in model_config.name_or_path.lower()
self.boundary_ratio = boundary_ratio_i2v if self.is_i2v else boundary_ratio_t2v

# multistage boundaries split the models up when sampling timesteps
# for wan 2.2 14b. the timesteps are 1000-875 for transformer 1 and 875-0 for transformer 2
self.multistage_boundaries: List[float] = [0.875, 0.0]
# for wan 2.2 14b I2V: timesteps 1000-900 for transformer 1 and 900-0 for transformer 2
# for wan 2.2 14b T2V: timesteps 1000-875 for transformer 1 and 875-0 for transformer 2
self.multistage_boundaries: List[float] = [self.boundary_ratio, 0.0]

self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True)
self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True)
Expand Down Expand Up @@ -339,7 +345,7 @@ def load_wan_transformer(self, transformer_path, subfolder=None):
transformer_2=transformer_2,
torch_dtype=self.torch_dtype,
device=self.device_torch,
boundary_ratio=boundary_ratio_t2v,
boundary_ratio=self.boundary_ratio,
low_vram=self.model_config.low_vram,
)

Expand All @@ -363,8 +369,7 @@ def get_generation_pipeline(self):
expand_timesteps=self._wan_expand_timesteps,
device=self.device_torch,
aggressive_offload=self.model_config.low_vram,
# todo detect if it is i2v or t2v
boundary_ratio=boundary_ratio_t2v,
boundary_ratio=self.boundary_ratio,
)

# pipeline = pipeline.to(self.device_torch)
Expand Down
35 changes: 35 additions & 0 deletions extensions_built_in/sd_trainer/SDTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,35 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):

def before_model_load(self):
pass

def _calculate_grad_norm(self, params):
if params is None or len(params) == 0:
return None

if isinstance(params[0], dict):
param_iterable = (p for group in params for p in group.get('params', []))
else:
param_iterable = params

total_norm_sq = None
for param in param_iterable:
if param is None:
continue
grad = getattr(param, 'grad', None)
if grad is None:
continue
if grad.is_sparse:
grad = grad.coalesce()._values()
grad_norm = grad.detach().float().norm(2)
if total_norm_sq is None:
total_norm_sq = grad_norm.pow(2)
else:
total_norm_sq = total_norm_sq + grad_norm.pow(2)

if total_norm_sq is None:
return None

return total_norm_sq.sqrt()

def cache_sample_prompts(self):
if self.train_config.disable_sampling:
Expand Down Expand Up @@ -2030,7 +2059,11 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD
torch.cuda.empty_cache()


grad_norm_value = None
if not self.is_grad_accumulation_step:
grad_norm_tensor = self._calculate_grad_norm(self.params)
if grad_norm_tensor is not None:
grad_norm_value = grad_norm_tensor.item()
# fix this for multi params
if self.train_config.optimizer != 'adafactor':
if isinstance(self.params[0], dict):
Expand Down Expand Up @@ -2068,6 +2101,8 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD
loss_dict = OrderedDict(
{'loss': (total_loss / len(batch_list)).item()}
)
if grad_norm_value is not None:
loss_dict['grad_norm'] = grad_norm_value

self.end_of_training_loop()

Expand Down
44 changes: 23 additions & 21 deletions jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2176,7 +2176,7 @@ def run(self):
if self.torch_profiler is not None:
torch.cuda.synchronize() # Make sure all CUDA ops are done
self.torch_profiler.stop()

print("\n==== Profile Results ====")
print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
self.timer.stop('train_loop')
Expand All @@ -2188,28 +2188,30 @@ def run(self):
if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
self.adapter.clear_memory()

with torch.no_grad():
# torch.cuda.empty_cache()
# if optimizer has get_lrs method, then use it
if hasattr(optimizer, 'get_avg_learning_rate'):
learning_rate = optimizer.get_avg_learning_rate()
elif hasattr(optimizer, 'get_learning_rates'):
learning_rate = optimizer.get_learning_rates()[0]
elif self.train_config.optimizer.lower().startswith('dadaptation') or \
self.train_config.optimizer.lower().startswith('prodigy'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
# Only update progress bar if we didn't OOM (loss_dict exists)
if not did_oom:
with torch.no_grad():
# torch.cuda.empty_cache()
# if optimizer has get_lrs method, then use it
if hasattr(optimizer, 'get_avg_learning_rate'):
learning_rate = optimizer.get_avg_learning_rate()
elif hasattr(optimizer, 'get_learning_rates'):
learning_rate = optimizer.get_learning_rates()[0]
elif self.train_config.optimizer.lower().startswith('dadaptation') or \
self.train_config.optimizer.lower().startswith('prodigy'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']

prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"
prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"

if self.progress_bar is not None:
self.progress_bar.set_postfix_str(prog_bar_string)
if self.progress_bar is not None:
self.progress_bar.set_postfix_str(prog_bar_string)

# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
Expand Down