|
10 | 10 |
|
11 | 11 | from tqdm import tqdm |
12 | 12 | import torch |
| 13 | + |
13 | 14 | try: |
14 | 15 | import intel_extension_for_pytorch as ipex |
| 16 | + |
15 | 17 | if torch.xpu.is_available(): |
16 | 18 | from library.ipex import ipex_init |
| 19 | + |
17 | 20 | ipex_init() |
18 | 21 | except Exception: |
19 | 22 | pass |
@@ -193,14 +196,20 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): |
193 | 196 |
|
194 | 197 | for m in training_models: |
195 | 198 | m.requires_grad_(True) |
196 | | - params = [] |
197 | | - for m in training_models: |
198 | | - params.extend(m.parameters()) |
199 | | - params_to_optimize = params |
| 199 | + |
| 200 | + trainable_params = [] |
| 201 | + if args.learning_rate_te is None or not args.train_text_encoder: |
| 202 | + for m in training_models: |
| 203 | + trainable_params.extend(m.parameters()) |
| 204 | + else: |
| 205 | + trainable_params = [ |
| 206 | + {"params": list(unet.parameters()), "lr": args.learning_rate}, |
| 207 | + {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te}, |
| 208 | + ] |
200 | 209 |
|
201 | 210 | # 学習に必要なクラスを準備する |
202 | 211 | accelerator.print("prepare optimizer, data loader etc.") |
203 | | - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) |
| 212 | + _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) |
204 | 213 |
|
205 | 214 | # dataloaderを準備する |
206 | 215 | # DataLoaderのプロセス数:0はメインプロセスになる |
@@ -340,7 +349,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): |
340 | 349 | else: |
341 | 350 | target = noise |
342 | 351 |
|
343 | | - if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss,: |
| 352 | + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: |
344 | 353 | # do not mean over batch dimension for snr weight or scale v-pred loss |
345 | 354 | loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") |
346 | 355 | loss = loss.mean([1, 2, 3]) |
@@ -476,6 +485,12 @@ def setup_parser() -> argparse.ArgumentParser: |
476 | 485 |
|
477 | 486 | parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") |
478 | 487 | parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") |
| 488 | + parser.add_argument( |
| 489 | + "--learning_rate_te", |
| 490 | + type=float, |
| 491 | + default=None, |
| 492 | + help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ", |
| 493 | + ) |
479 | 494 |
|
480 | 495 | return parser |
481 | 496 |
|
|
0 commit comments