Commit 96d877b

support separate LR for Text Encoder for SD1/2

1 parent e72020a

2 files changed: +39, -9 lines (fine_tune.py, train_db.py)

fine_tune.py

Lines changed: 21 additions & 6 deletions

@@ -10,10 +10,13 @@
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass

@@ -193,14 +196,20 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     for m in training_models:
         m.requires_grad_(True)
-    params = []
-    for m in training_models:
-        params.extend(m.parameters())
-    params_to_optimize = params
+
+    trainable_params = []
+    if args.learning_rate_te is None or not args.train_text_encoder:
+        for m in training_models:
+            trainable_params.extend(m.parameters())
+    else:
+        trainable_params = [
+            {"params": list(unet.parameters()), "lr": args.learning_rate},
+            {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+        ]

     # prepare the classes needed for training
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+    _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

     # prepare the dataloader
     # number of DataLoader processes: 0 means the main process

@@ -340,7 +349,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         else:
             target = noise

-        if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss,:
+        if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
             # do not mean over batch dimension for snr weight or scale v-pred loss
             loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
             loss = loss.mean([1, 2, 3])

@@ -476,6 +485,12 @@ def setup_parser() -> argparse.ArgumentParser:

     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder")
+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet",
+    )

     return parser
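The branch above leans on PyTorch's parameter groups: `torch.optim` optimizers accept either a flat iterable of tensors or a list of dicts, and a per-group "lr" key overrides the optimizer-wide default. A minimal, self-contained sketch of that mechanism, where the two `Linear` modules stand in for the real unet and text_encoder:

import torch

# stand-ins for the real unet / text_encoder (illustrative only)
unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)

# one dict per parameter group; each "lr" overrides the optimizer default
trainable_params = [
    {"params": list(unet.parameters()), "lr": 1e-4},          # unet LR
    {"params": list(text_encoder.parameters()), "lr": 5e-5},  # text encoder LR
]
optimizer = torch.optim.AdamW(trainable_params)

for group in optimizer.param_groups:
    print(group["lr"])  # prints 0.0001, then 5e-05

This is also why the flat-list path stays the default: with --learning_rate_te unset, both models remain in a single group at args.learning_rate, preserving the old behavior.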

train_db.py

Lines changed: 18 additions & 3 deletions

@@ -11,10 +11,13 @@
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass

@@ -164,11 +167,17 @@ def train(args):
     # prepare the classes needed for training
     accelerator.print("prepare optimizer, data loader etc.")
     if train_text_encoder:
-        # without a list, adamw8bit crashes
-        trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        if args.learning_rate_te is None:
+            # without a list, adamw8bit crashes
+            trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        else:
+            trainable_params = [
+                {"params": list(unet.parameters()), "lr": args.learning_rate},
+                {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+            ]
     else:
         trainable_params = unet.parameters()
-
+
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)

     # prepare the dataloader

@@ -461,6 +470,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)

+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet",
+    )
     parser.add_argument(
         "--no_token_padding",
         action="store_true",
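As a usage sketch, a fine_tune.py run with the new option might look like the following; --train_text_encoder and --learning_rate_te come from this diff, while the model path and base learning rate flags are placeholder examples of the scripts' existing options (dataset and output arguments omitted):

# text encoder trained at half the U-Net rate
python fine_tune.py \
    --pretrained_model_name_or_path=<base model> \
    --train_text_encoder \
    --learning_rate=1e-5 \
    --learning_rate_te=5e-6

Leaving --learning_rate_te unset keeps the previous behavior, with the text encoder trained at --learning_rate.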
