15 changes: 10 additions & 5 deletions extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
@@ -188,9 +188,15 @@ def __init__(
         self._wan_cache = None

         self.is_multistage = True
+
+        # Detect if this is I2V or T2V model
+        self.is_i2v = 'i2v' in model_config.name_or_path.lower()
+        self.boundary_ratio = boundary_ratio_i2v if self.is_i2v else boundary_ratio_t2v
+
         # multistage boundaries split the models up when sampling timesteps
-        # for wan 2.2 14b. the timesteps are 1000-875 for transformer 1 and 875-0 for transformer 2
-        self.multistage_boundaries: List[float] = [0.875, 0.0]
+        # for wan 2.2 14b I2V: timesteps 1000-900 for transformer 1 and 900-0 for transformer 2
+        # for wan 2.2 14b T2V: timesteps 1000-875 for transformer 1 and 875-0 for transformer 2
+        self.multistage_boundaries: List[float] = [self.boundary_ratio, 0.0]

         self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True)
         self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True)
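Note: the boundary is a fraction of the 1000-step noise schedule, so a ratio of 0.9 puts the expert handoff at timestep 900 (I2V) and 0.875 puts it at 875 (T2V). A minimal sketch of the selection rule (illustrative only, not code from this PR):

def pick_transformer(timestep: float, boundary_ratio: float) -> int:
    # Return 1 for the high-noise expert, 2 for the low-noise expert
    boundary_timestep = boundary_ratio * 1000
    return 1 if timestep >= boundary_timestep else 2

assert pick_transformer(950, 0.9) == 1    # I2V: 1000-900 handled by transformer 1
assert pick_transformer(880, 0.9) == 2    # I2V: below 900 falls to transformer 2
assert pick_transformer(880, 0.875) == 1  # T2V boundary sits lower, at 875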
@@ -339,7 +345,7 @@ def load_wan_transformer(self, transformer_path, subfolder=None):
             transformer_2=transformer_2,
             torch_dtype=self.torch_dtype,
             device=self.device_torch,
-            boundary_ratio=boundary_ratio_t2v,
+            boundary_ratio=self.boundary_ratio,
             low_vram=self.model_config.low_vram,
         )

@@ -363,8 +369,7 @@ def get_generation_pipeline(self):
             expand_timesteps=self._wan_expand_timesteps,
             device=self.device_torch,
             aggressive_offload=self.model_config.low_vram,
-            # todo detect if it is i2v or t2v
-            boundary_ratio=boundary_ratio_t2v,
+            boundary_ratio=self.boundary_ratio,
         )

         # pipeline = pipeline.to(self.device_torch)
35 changes: 35 additions & 0 deletions extensions_built_in/sd_trainer/SDTrainer.py
@@ -111,6 +111,35 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):

     def before_model_load(self):
         pass
+
+    def _calculate_grad_norm(self, params):
+        if params is None or len(params) == 0:
+            return None
+
+        if isinstance(params[0], dict):
+            param_iterable = (p for group in params for p in group.get('params', []))
+        else:
+            param_iterable = params
+
+        total_norm_sq = None
+        for param in param_iterable:
+            if param is None:
+                continue
+            grad = getattr(param, 'grad', None)
+            if grad is None:
+                continue
+            if grad.is_sparse:
+                grad = grad.coalesce()._values()
+            grad_norm = grad.detach().float().norm(2)
+            if total_norm_sq is None:
+                total_norm_sq = grad_norm.pow(2)
+            else:
+                total_norm_sq = total_norm_sq + grad_norm.pow(2)
+
+        if total_norm_sq is None:
+            return None
+
+        return total_norm_sq.sqrt()

     def cache_sample_prompts(self):
         if self.train_config.disable_sampling:
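Note: the helper accumulates squared per-tensor norms and takes a single square root at the end, i.e. the global L2 norm sqrt(sum_i ||g_i||^2) over all gradients, staying on-device until .item() is called later. A quick standalone check (hypothetical, not part of the PR) against torch.nn.utils.clip_grad_norm_, which returns the same total norm when no clipping occurs:

import torch

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
loss = sum((p ** 2).sum() for p in params)
loss.backward()

manual = torch.sqrt(sum(p.grad.detach().float().norm(2) ** 2 for p in params))
reference = torch.nn.utils.clip_grad_norm_(params, max_norm=1e9)  # huge max_norm -> no clipping
assert torch.allclose(manual, reference)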
@@ -2030,7 +2059,11 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD
             torch.cuda.empty_cache()


+        grad_norm_value = None
         if not self.is_grad_accumulation_step:
+            grad_norm_tensor = self._calculate_grad_norm(self.params)
+            if grad_norm_tensor is not None:
+                grad_norm_value = grad_norm_tensor.item()
             # fix this for multi params
             if self.train_config.optimizer != 'adafactor':
                 if isinstance(self.params[0], dict):
@@ -2068,6 +2101,8 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD
         loss_dict = OrderedDict(
             {'loss': (total_loss / len(batch_list)).item()}
         )
+        if grad_norm_value is not None:
+            loss_dict['grad_norm'] = grad_norm_value

         self.end_of_training_loop()

63 changes: 41 additions & 22 deletions jobs/process/BaseSDTrainProcess.py
@@ -1790,6 +1790,8 @@ def run(self):
                     config['default_lr'] = self.train_config.lr
                 if 'learning_rate' in sig.parameters:
                     config['learning_rate'] = self.train_config.lr
+                if 'optimizer_params' in sig.parameters:
+                    config['optimizer_params'] = self.train_config.optimizer_params
                 params_net = self.network.prepare_optimizer_params(
                     **config
                 )
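Note: prepare_optimizer_params signatures vary across network implementations, so each kwarg is forwarded only if the target signature declares it. A hedged sketch of the gating pattern (the function below is a stand-in, not the real network method):

import inspect

def prepare_optimizer_params(default_lr=None, optimizer_params=None):
    return {'default_lr': default_lr, 'optimizer_params': optimizer_params}

config = {}
sig = inspect.signature(prepare_optimizer_params)
if 'optimizer_params' in sig.parameters:
    config['optimizer_params'] = {'betas': (0.9, 0.99)}
print(prepare_optimizer_params(**config))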
@@ -2176,7 +2178,7 @@ def run(self):
             if self.torch_profiler is not None:
                 torch.cuda.synchronize()  # Make sure all CUDA ops are done
                 self.torch_profiler.stop()

                 print("\n==== Profile Results ====")
                 print(self.torch_profiler.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
         self.timer.stop('train_loop')
@@ -2188,28 +2190,45 @@ def run(self):
                 if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
                     self.adapter.clear_memory()

-                with torch.no_grad():
-                    # torch.cuda.empty_cache()
-                    # if optimizer has get_lrs method, then use it
-                    if hasattr(optimizer, 'get_avg_learning_rate'):
-                        learning_rate = optimizer.get_avg_learning_rate()
-                    elif hasattr(optimizer, 'get_learning_rates'):
-                        learning_rate = optimizer.get_learning_rates()[0]
-                    elif self.train_config.optimizer.lower().startswith('dadaptation') or \
-                            self.train_config.optimizer.lower().startswith('prodigy'):
-                        learning_rate = (
-                            optimizer.param_groups[0]["d"] *
-                            optimizer.param_groups[0]["lr"]
-                        )
-                    else:
-                        learning_rate = optimizer.param_groups[0]['lr']
-
-                    prog_bar_string = f"lr: {learning_rate:.1e}"
-                    for key, value in loss_dict.items():
-                        prog_bar_string += f" {key}: {value:.3e}"
+                # Only update progress bar if we didn't OOM (loss_dict exists)
+                if not did_oom:
+                    with torch.no_grad():
+                        # torch.cuda.empty_cache()
+                        # if optimizer has get_lrs method, then use it
+                        if hasattr(optimizer, 'get_avg_learning_rate'):
+                            # Check if this is MoE with multiple param groups
+                            if hasattr(optimizer, 'get_learning_rates') and len(optimizer.param_groups) > 1:
+                                # Show per-expert LRs for MoE
+                                group_lrs = optimizer.get_learning_rates()
+                                learning_rate = None  # Will use group_lrs instead
+                            else:
+                                learning_rate = optimizer.get_avg_learning_rate()
+                        elif hasattr(optimizer, 'get_learning_rates'):
+                            learning_rate = optimizer.get_learning_rates()[0]
+                        elif self.train_config.optimizer.lower().startswith('dadaptation') or \
+                                self.train_config.optimizer.lower().startswith('prodigy'):
+                            learning_rate = (
+                                optimizer.param_groups[0]["d"] *
+                                optimizer.param_groups[0]["lr"]
+                            )
+                        else:
+                            learning_rate = optimizer.param_groups[0]['lr']
+
+                        # Format LR string (per-expert for MoE, single value otherwise)
+                        if hasattr(optimizer, 'get_avg_learning_rate') and learning_rate is None:
+                            # MoE: show each expert's LR
+                            lr_strings = []
+                            for i, lr in enumerate(group_lrs):
+                                lr_val = lr.item() if hasattr(lr, 'item') else lr
+                                lr_strings.append(f"lr{i}: {lr_val:.1e}")
+                            prog_bar_string = " ".join(lr_strings)
+                        else:
+                            prog_bar_string = f"lr: {learning_rate:.1e}"
+                        for key, value in loss_dict.items():
+                            prog_bar_string += f" {key}: {value:.3e}"

-                if self.progress_bar is not None:
-                    self.progress_bar.set_postfix_str(prog_bar_string)
+                    if self.progress_bar is not None:
+                        self.progress_bar.set_postfix_str(prog_bar_string)

                 # if the batch is a DataLoaderBatchDTO, then we need to clean it up
                 if isinstance(batch, DataLoaderBatchDTO):
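Note: the MoE branch is duck-typed; any optimizer exposing get_learning_rates() with more than one param group gets the per-expert display. A hypothetical wrapper satisfying that interface (not the PR's actual optimizer class):

import torch

class TwoExpertOptimizer(torch.optim.AdamW):
    def get_learning_rates(self):
        return [group['lr'] for group in self.param_groups]

    def get_avg_learning_rate(self):
        lrs = self.get_learning_rates()
        return sum(lrs) / len(lrs)

high_noise, low_noise = torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)
opt = TwoExpertOptimizer([
    {'params': high_noise.parameters(), 'lr': 1e-4},  # expert 0
    {'params': low_noise.parameters(), 'lr': 5e-5},   # expert 1
])
# Two param groups -> the progress-bar postfix reads "lr0: 1.0e-04 lr1: 5.0e-05"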
84 changes: 74 additions & 10 deletions toolkit/buckets.py
@@ -6,7 +6,31 @@ class BucketResolution(TypedDict):
     height: int


-# resolutions SDXL was trained on with a 1024x1024 base resolution
+# Video-friendly resolutions with common aspect ratios
+# Base resolution: 1024x1024
+# Keep only PRIMARY buckets to avoid videos being assigned to undersized buckets
+resolutions_video_1024: List[BucketResolution] = [
+    # Square
+    {"width": 1024, "height": 1024},  # 1:1
+
+    # 16:9 landscape (1.778 aspect - YouTube, TV standard)
+    {"width": 1024, "height": 576},
+
+    # 9:16 portrait (0.562 aspect - TikTok, Instagram Reels)
+    {"width": 576, "height": 1024},
+
+    # 4:3 landscape (1.333 aspect - older content)
+    {"width": 1024, "height": 768},
+
+    # 3:4 portrait (0.75 aspect)
+    {"width": 768, "height": 1024},
+
+    # Slightly wider/taller variants for flexibility
+    {"width": 1024, "height": 640},  # 1.6 aspect
+    {"width": 640, "height": 1024},  # 0.625 aspect
+]
+
+# SDXL resolutions (kept for backwards compatibility)
 resolutions_1024: List[BucketResolution] = [
     # SDXL Base resolution
     {"width": 1024, "height": 1024},
@@ -56,19 +80,61 @@ class BucketResolution(TypedDict):
{"width": 128, "height": 8192},
]

-def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
-    # determine scaler form 1024 to resolution
-    scaler = resolution / 1024
+def get_bucket_sizes(resolution: int = 512, divisibility: int = 8, use_video_buckets: bool = True, max_pixels_per_frame: int = None) -> List[BucketResolution]:
+    # Use video-friendly buckets by default for better aspect ratio preservation
+    base_resolutions = resolutions_video_1024 if use_video_buckets else resolutions_1024
+
+    # If max_pixels_per_frame is specified, use pixel budget scaling
+    # This maximizes resolution for each aspect ratio while keeping memory usage consistent
+    if max_pixels_per_frame is not None:
+        bucket_size_list = []
+        for bucket in base_resolutions:
+            # Calculate aspect ratio
+            base_aspect = bucket["width"] / bucket["height"]
+
+            # Calculate optimal dimensions for this aspect ratio within pixel budget
+            # For aspect ratio a = w/h and pixel budget p = w*h:
+            #   w = sqrt(p * a), h = sqrt(p / a)
+            optimal_width = (max_pixels_per_frame * base_aspect) ** 0.5
+            optimal_height = (max_pixels_per_frame / base_aspect) ** 0.5
+
+            # Round down to divisibility
+            width = int(optimal_width)
+            height = int(optimal_height)
+            width = width - (width % divisibility)
+            height = height - (height % divisibility)
+
+            # Verify we're under budget (should always be true with round-down)
+            actual_pixels = width * height
+            if actual_pixels > max_pixels_per_frame:
+                # Safety check - scale down if somehow over budget
+                scale = (max_pixels_per_frame / actual_pixels) ** 0.5
+                width = int(width * scale)
+                height = int(height * scale)
+                width = width - (width % divisibility)
+                height = height - (height % divisibility)
+
+            bucket_size_list.append({"width": width, "height": height})
+
+        return bucket_size_list
+
+    # Original scaling logic (for backwards compatibility)
+    scaler = resolution / 1024
     bucket_size_list = []
-    for bucket in resolutions_1024:
+    for bucket in base_resolutions:
         # must be divisible by 8
         width = int(bucket["width"] * scaler)
         height = int(bucket["height"] * scaler)
         if width % divisibility != 0:
             width = width - (width % divisibility)
         if height % divisibility != 0:
             height = height - (height % divisibility)
+
+        # Filter buckets where any dimension exceeds the resolution parameter
+        # This ensures memory usage stays within bounds for the target resolution
+        if max(width, height) > resolution:
+            continue
+
         bucket_size_list.append({"width": width, "height": height})

     return bucket_size_list
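Note: a worked example of the pixel-budget math under stated assumptions (the resolutions_video_1024 ordering above; budget p = 1024*576 = 589,824 pixels). For the 16:9 bucket a = 1024/576, so w = sqrt(p*a) = 1024 and h = sqrt(p/a) = 576; the square bucket lands on 768, since 768**2 = 589,824:

buckets = get_bucket_sizes(divisibility=16, max_pixels_per_frame=1024 * 576)
print(buckets[0])  # 1:1  -> expect {"width": 768, "height": 768}
print(buckets[1])  # 16:9 -> expect {"width": 1024, "height": 576}, the budget shape itself
for b in buckets:
    assert b["width"] * b["height"] <= 1024 * 576  # rounding down keeps every bucket under budget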
@@ -86,17 +152,15 @@ def get_bucket_for_image_size(
     height: int,
     bucket_size_list: List[BucketResolution] = None,
     resolution: Union[int, None] = None,
-    divisibility: int = 8
+    divisibility: int = 8,
+    max_pixels_per_frame: int = None
 ) -> BucketResolution:

     if bucket_size_list is None and resolution is None:
         # get resolution from width and height
         resolution = get_resolution(width, height)
     if bucket_size_list is None:
-        # if real resolution is smaller, use that instead
-        real_resolution = get_resolution(width, height)
-        resolution = min(resolution, real_resolution)
-        bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility)
+        bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility, max_pixels_per_frame=max_pixels_per_frame)

     # Check for exact match first
     for bucket in bucket_size_list:
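Note: a hedged usage sketch for the extended signature; assuming nearest-aspect matching after the exact-match check (that fallback logic sits outside this hunk), a 1920x1080 source with a 16:9 pixel budget should resolve to the 1024x576 bucket:

bucket = get_bucket_for_image_size(
    1920, 1080,
    resolution=1024,
    divisibility=16,
    max_pixels_per_frame=1024 * 576,
)
# 16:9 input -> expected result: {"width": 1024, "height": 576}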
1 change: 1 addition & 0 deletions toolkit/config_modules.py
@@ -836,6 +836,7 @@ def __init__(self, **kwargs):
         self.random_scale: bool = kwargs.get('random_scale', False)
         self.random_crop: bool = kwargs.get('random_crop', False)
         self.resolution: int = kwargs.get('resolution', 512)
+        self.max_pixels_per_frame: int = kwargs.get('max_pixels_per_frame', None)
         self.scale: float = kwargs.get('scale', 1.0)
         self.buckets: bool = kwargs.get('buckets', True)
         self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
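Note: with the new kwarg plumbed through, a dataset can cap per-frame pixels independently of resolution. A hypothetical config (field names other than max_pixels_per_frame follow the existing DatasetConfig kwargs; values are illustrative):

dataset_config = DatasetConfig(
    folder_path='/data/videos',       # assumed existing kwarg
    resolution=1024,
    buckets=True,
    bucket_tolerance=64,
    max_pixels_per_frame=1024 * 576,  # new: per-frame pixel budget, aspect ratio preserved
)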
1 change: 1 addition & 0 deletions toolkit/data_loader.py
@@ -332,6 +332,7 @@ def __getitem__(self, index):
             width=img2.width,
             height=img2.height,
             resolution=self.size,
+            max_pixels_per_frame=getattr(self.dataset_config, 'max_pixels_per_frame', None) if hasattr(self, 'dataset_config') else None
             # divisibility=self.
         )

7 changes: 5 additions & 2 deletions toolkit/dataloader_mixins.py
@@ -209,6 +209,7 @@ def setup_buckets(self: 'AiToolkitDataset', quiet=False):
        config: 'DatasetConfig' = self.dataset_config
        resolution = config.resolution
        bucket_tolerance = config.bucket_tolerance
+       max_pixels_per_frame = config.max_pixels_per_frame
        file_list: List['FileItemDTO'] = self.file_list

        # for file_item in enumerate(file_list):
@@ -240,7 +241,8 @@ def setup_buckets(self: 'AiToolkitDataset', quiet=False):
            bucket_resolution = get_bucket_for_image_size(
                width, height,
                resolution=resolution,
-               divisibility=bucket_tolerance
+               divisibility=bucket_tolerance,
+               max_pixels_per_frame=max_pixels_per_frame
            )

            # Calculate scale factors for width and height
@@ -1601,7 +1603,8 @@ def setup_poi_bucket(self: 'FileItemDTO'):
        bucket_resolution = get_bucket_for_image_size(
            new_width, new_height,
            resolution=self.dataset_config.resolution,
-           divisibility=bucket_tolerance
+           divisibility=bucket_tolerance,
+           max_pixels_per_frame=self.dataset_config.max_pixels_per_frame
        )

        width_scale_factor = bucket_resolution["width"] / new_width