Skip to content

Commit 86d107e

Browse files
author
AI Toolkit Contributor
committed
Merge remote-tracking branch 'upstream/main'
2 parents a2749c5 + 42e5e3c commit 86d107e

File tree

9 files changed

+193
-64
lines changed

9 files changed

+193
-64
lines changed

extensions_built_in/sd_trainer/SDTrainer.py

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,13 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
9595
raise ValueError("diff_output_preservation requires a network to be set")
9696
if self.train_config.train_text_encoder:
9797
raise ValueError("diff_output_preservation is not supported with train_text_encoder")
98-
99-
# always do a prior prediction when doing diff output preservation
98+
99+
if self.train_config.blank_prompt_preservation:
100+
if self.network_config is None:
101+
raise ValueError("blank_prompt_preservation requires a network to be set")
102+
103+
if self.train_config.blank_prompt_preservation or self.train_config.diff_output_preservation:
104+
# always do a prior prediction when doing output preservation
100105
self.do_prior_prediction = True
101106

102107
# store the loss target for a batch so we can use it in a loss
@@ -372,6 +377,13 @@ def hook_before_train_loop(self):
372377
self.sd.text_encoder_to("cpu")
373378
flush()
374379

380+
if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None:
381+
# make sure we have this if not unloading
382+
self.cached_blank_embeds = self.sd.encode_prompt("").to(
383+
self.device_torch,
384+
dtype=self.sd.torch_dtype
385+
).detach()
386+
375387
if self.train_config.diffusion_feature_extractor_path is not None:
376388
vae = self.sd.vae
377389
# if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
@@ -634,33 +646,28 @@ def calculate_loss(
634646
stepped_latents = torch.cat(stepped_chunks, dim=0)
635647

636648
stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
637-
# resize to half the size of the latents
638-
stepped_latents_half = torch.nn.functional.interpolate(
639-
stepped_latents,
640-
size=(stepped_latents.shape[2] // 2, stepped_latents.shape[3] // 2),
641-
mode='bilinear',
642-
align_corners=False
643-
)
644-
pred_features = self.dfe(stepped_latents.float())
645-
pred_features_half = self.dfe(stepped_latents_half.float())
649+
sl = stepped_latents
650+
if len(sl.shape) == 5:
651+
# video B,C,T,H,W
652+
sl = sl.permute(0, 2, 1, 3, 4) # B,T,C,H,W
653+
b, t, c, h, w = sl.shape
654+
sl = sl.reshape(b * t, c, h, w)
655+
pred_features = self.dfe(sl.float())
646656
with torch.no_grad():
647-
target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
648-
batch_latents_half = torch.nn.functional.interpolate(
649-
batch.latents.to(self.device_torch, dtype=torch.float32),
650-
size=(batch.latents.shape[2] // 2, batch.latents.shape[3] // 2),
651-
mode='bilinear',
652-
align_corners=False
653-
)
654-
target_features_half = self.dfe(batch_latents_half)
657+
bl = batch.latents
658+
bl = bl.to(self.sd.vae.device)
659+
if len(bl.shape) == 5:
660+
# video B,C,T,H,W
661+
bl = bl.permute(0, 2, 1, 3, 4) # B,T,C,H,W
662+
b, t, c, h, w = bl.shape
663+
bl = bl.reshape(b * t, c, h, w)
664+
target_features = self.dfe(bl.float())
655665
# scale dfe so it is weaker at higher noise levels
656666
dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)
657667

658668
dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
659669
self.train_config.diffusion_feature_extractor_weight * dfe_scaler
660-
661-
dfe_loss_half = torch.nn.functional.mse_loss(pred_features_half, target_features_half, reduction="none") * \
662-
self.train_config.diffusion_feature_extractor_weight * dfe_scaler
663-
additional_loss += dfe_loss.mean() + dfe_loss_half.mean()
670+
additional_loss += dfe_loss.mean()
664671
elif self.dfe.version == 2:
665672
# version 2
666673
# do diffusion feature extraction on target
@@ -1798,6 +1805,14 @@ def get_adapter_multiplier():
17981805
if self.train_config.diff_output_preservation:
17991806
prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
18001807

1808+
if self.train_config.blank_prompt_preservation:
1809+
blank_embeds = self.cached_blank_embeds.clone().detach().to(
1810+
self.device_torch, dtype=dtype
1811+
)
1812+
prior_embeds_to_use = concat_prompt_embeds(
1813+
[blank_embeds] * noisy_latents.shape[0]
1814+
)
1815+
18011816
prior_pred = self.get_prior_prediction(
18021817
noisy_latents=noisy_latents,
18031818
conditional_embeds=prior_embeds_to_use,
@@ -1973,7 +1988,8 @@ def get_adapter_multiplier():
19731988
prior_to_calculate_loss = prior_pred
19741989
# if we are doing diff_output_preservation and not doing inverted masked prior
19751990
# then we need to send none here so it will not target the prior
1976-
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
1991+
doing_preservation = self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation
1992+
if doing_preservation and not do_inverted_masked_prior:
19771993
prior_to_calculate_loss = None
19781994

19791995
loss = self.calculate_loss(
@@ -1986,24 +2002,34 @@ def get_adapter_multiplier():
19862002
prior_pred=prior_to_calculate_loss,
19872003
)
19882004

1989-
if self.train_config.diff_output_preservation:
2005+
if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation:
19902006
# send the loss backwards otherwise checkpointing will fail
19912007
self.accelerator.backward(loss)
19922008
normal_loss = loss.detach() # dont send backward again
19932009

1994-
dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
1995-
dop_pred = self.predict_noise(
2010+
with torch.no_grad():
2011+
if self.train_config.diff_output_preservation:
2012+
preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
2013+
elif self.train_config.blank_prompt_preservation:
2014+
blank_embeds = self.cached_blank_embeds.clone().detach().to(
2015+
self.device_torch, dtype=dtype
2016+
)
2017+
preservation_embeds = concat_prompt_embeds(
2018+
[blank_embeds] * noisy_latents.shape[0]
2019+
)
2020+
preservation_pred = self.predict_noise(
19962021
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
19972022
timesteps=timesteps,
1998-
conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
2023+
conditional_embeds=preservation_embeds.to(self.device_torch, dtype=dtype),
19992024
unconditional_embeds=unconditional_embeds,
20002025
batch=batch,
20012026
**pred_kwargs
20022027
)
2003-
dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
2004-
self.accelerator.backward(dop_loss)
2005-
2006-
loss = normal_loss + dop_loss
2028+
multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier
2029+
preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier
2030+
self.accelerator.backward(preservation_loss)
2031+
2032+
loss = normal_loss + preservation_loss
20072033
loss = loss.clone().detach()
20082034
# require grad again so the backward won't fail
20092035
loss.requires_grad_(True)

toolkit/config_modules.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,11 @@ def __init__(self, **kwargs):
457457
self.diff_output_preservation_multiplier = kwargs.get('diff_output_preservation_multiplier', 1.0)
458458
# If the trigger word is in the prompt, we will use this class name to replace it, e.g. "sks woman" -> "woman"
459459
self.diff_output_preservation_class = kwargs.get('diff_output_preservation_class', '')
460-
460+
461+
# blank prompt preservation will preserve the model's knowledge of a blank prompt
462+
self.blank_prompt_preservation = kwargs.get('blank_prompt_preservation', False)
463+
self.blank_prompt_preservation_multiplier = kwargs.get('blank_prompt_preservation_multiplier', 1.0)
464+
461465
# legacy
462466
if match_adapter_assist and self.match_adapter_chance == 0.0:
463467
self.match_adapter_chance = 1.0
@@ -1325,5 +1329,8 @@ def validate_configs(
13251329
if model_config.arch == 'qwen_image_edit':
13261330
if train_config.unload_text_encoder:
13271331
raise ValueError("Cannot cache unload text encoder with qwen_image_edit model. Control images are encoded with text embeddings. You can cache the text embeddings though")
1332+
1333+
if train_config.diff_output_preservation and train_config.blank_prompt_preservation:
1334+
raise ValueError("Cannot use both differential output preservation and blank prompt preservation at the same time. Please set one of them to False.")
13281335

13291336

toolkit/optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def get_optimizer(
9393
optimizer_params['scale_parameter'] = False
9494
if 'warmup_init' not in optimizer_params:
9595
optimizer_params['warmup_init'] = False
96-
optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
96+
optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
9797
elif lower_type == 'automagic':
9898
from toolkit.optimizers.automagic import Automagic
9999
optimizer = Automagic(params, lr=float(learning_rate), **optimizer_params)

ui/src/app/jobs/new/SimpleJob.tsx

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,12 @@ export default function SimpleJob({
215215
</FormGroup>
216216
)}
217217
{modelArch?.additionalSections?.includes('model.qie.match_target_res') && (
218-
<Checkbox
219-
label="Match Target Res"
220-
docKey="model.qie.match_target_res"
221-
checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
222-
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
223-
/>
218+
<Checkbox
219+
label="Match Target Res"
220+
docKey="model.qie.match_target_res"
221+
checked={jobConfig.config.process[0].model.model_kwargs.match_target_res}
222+
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.match_target_res')}
223+
/>
224224
)}
225225
{modelArch?.additionalSections?.includes('model.layer_offloading') && (
226226
<>
@@ -586,16 +586,27 @@ export default function SimpleJob({
586586
</FormGroup>
587587
</div>
588588
<div>
589+
{disableSections.includes('train.diff_output_preservation') ||
590+
disableSections.includes('train.blank_prompt_preservation') ? null : (
591+
<FormGroup label="Regularization">
592+
<></>
593+
</FormGroup>
594+
)}
589595
{disableSections.includes('train.diff_output_preservation') ? null : (
590596
<>
591-
<FormGroup label="Regularization">
592-
<Checkbox
593-
label="Differential Output Preservation"
594-
className="pt-1"
595-
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
596-
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
597-
/>
598-
</FormGroup>
597+
<Checkbox
598+
label="Differential Output Preservation"
599+
docKey={'train.diff_output_preservation'}
600+
className="pt-1"
601+
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
602+
onChange={value => {
603+
setJobConfig(value, 'config.process[0].train.diff_output_preservation');
604+
if (value && jobConfig.config.process[0].train.blank_prompt_preservation) {
605+
// only one can be enabled at a time
606+
setJobConfig(false, 'config.process[0].train.blank_prompt_preservation');
607+
}
608+
}}
609+
/>
599610
{jobConfig.config.process[0].train.diff_output_preservation && (
600611
<>
601612
<NumberInput
@@ -610,7 +621,7 @@ export default function SimpleJob({
610621
/>
611622
<TextInput
612623
label="DOP Preservation Class"
613-
className="pt-2"
624+
className="pt-2 pb-4"
614625
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
615626
onChange={value =>
616627
setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
@@ -621,6 +632,39 @@ export default function SimpleJob({
621632
)}
622633
</>
623634
)}
635+
{disableSections.includes('train.blank_prompt_preservation') ? null : (
636+
<>
637+
<Checkbox
638+
label="Blank Prompt Preservation"
639+
docKey={'train.blank_prompt_preservation'}
640+
className="pt-1"
641+
checked={jobConfig.config.process[0].train.blank_prompt_preservation || false}
642+
onChange={value => {
643+
setJobConfig(value, 'config.process[0].train.blank_prompt_preservation');
644+
if (value && jobConfig.config.process[0].train.diff_output_preservation) {
645+
// only one can be enabled at a time
646+
setJobConfig(false, 'config.process[0].train.diff_output_preservation');
647+
}
648+
}}
649+
/>
650+
{jobConfig.config.process[0].train.blank_prompt_preservation && (
651+
<>
652+
<NumberInput
653+
label="BPP Loss Multiplier"
654+
className="pt-2"
655+
value={
656+
(jobConfig.config.process[0].train.blank_prompt_preservation_multiplier as number) || 1.0
657+
}
658+
onChange={value =>
659+
setJobConfig(value, 'config.process[0].train.blank_prompt_preservation_multiplier')
660+
}
661+
placeholder="eg. 1.0"
662+
min={0}
663+
/>
664+
</>
665+
)}
666+
</>
667+
)}
624668
</div>
625669
</div>
626670
</Card>

ui/src/app/jobs/new/options.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ type DisableableSections =
99
| 'network.conv'
1010
| 'trigger_word'
1111
| 'train.diff_output_preservation'
12+
| 'train.blank_prompt_preservation'
1213
| 'train.unload_text_encoder'
1314
| 'slider';
1415

ui/src/components/SampleImages.tsx

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,7 @@ export default function SampleImages({ job }: SampleImagesProps) {
154154

155155
switch (cols) {
156156
case 1:
157-
return 'grid-cols-1';
158157
case 2:
159-
return 'grid-cols-2';
160158
case 3:
161159
return 'grid-cols-3';
162160
case 4:
@@ -234,7 +232,7 @@ export default function SampleImages({ job }: SampleImagesProps) {
234232
case 40:
235233
return 'grid-cols-40';
236234
default:
237-
return 'grid-cols-1';
235+
return 'grid-cols-3';
238236
}
239237
}, [numSamples]);
240238

@@ -262,17 +260,38 @@ export default function SampleImages({ job }: SampleImagesProps) {
262260
{PageInfoContent}
263261
{sampleImages && (
264262
<div className={`grid ${gridColsClass} gap-1`}>
265-
{sampleImages.map((sample: string) => (
266-
<SampleImageCard
267-
key={sample}
268-
imageUrl={sample}
269-
numSamples={numSamples}
270-
sampleImages={sampleImages}
271-
alt="Sample Image"
272-
onClick={() => setSelectedSamplePath(sample)}
273-
observerRoot={containerRef.current}
274-
/>
275-
))}
263+
{sampleImages.map((sample: string, idx: number) => {
264+
// Compute current group (groups are size = numSamples)
265+
const groupIndex = Math.floor(idx / numSamples);
266+
const groupStart = groupIndex * numSamples;
267+
const groupEnd = Math.min(groupStart + numSamples, sampleImages.length);
268+
const groupSize = groupEnd - groupStart;
269+
const isEndOfGroup = idx === groupEnd - 1;
270+
271+
// Only enforce a MIN of 3 when the group's planned width is < 3
272+
const MIN_COLS = 3;
273+
const shouldPad = numSamples < MIN_COLS && groupSize < MIN_COLS;
274+
const padsNeeded = shouldPad ? MIN_COLS - groupSize : 0;
275+
276+
return (
277+
<div key={sample} className="contents">
278+
<SampleImageCard
279+
imageUrl={sample}
280+
numSamples={numSamples}
281+
sampleImages={sampleImages}
282+
alt="Sample Image"
283+
onClick={() => setSelectedSamplePath(sample)}
284+
observerRoot={containerRef.current}
285+
/>
286+
287+
{isEndOfGroup &&
288+
padsNeeded > 0 &&
289+
Array.from({ length: padsNeeded }).map((_, i) => (
290+
<div key={`pad-${groupIndex}-${i}`} className="invisible" />
291+
))}
292+
</div>
293+
);
294+
})}
276295
</div>
277296
)}
278297
</div>

0 commit comments

Comments
 (0)