Merged

Dev #1067

Add network_multiplier for dataset and train LoRA
kohya-ss committed Jan 20, 2024
commit fef172966fff02a5f918840843eb613ee1a6d50e
31 changes: 31 additions & 0 deletions README.md
@@ -257,11 +257,42 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.

- [Experimental] The network multiplier can now be specified per dataset in the training scripts for LoRA etc.
- This is an experimental option and may be removed or changed in the future.
- For example, if you train state A with a multiplier of `1.0` and state B with `-1.0`, you may be able to switch between states A and B at generation time by adjusting the LoRA application rate.
- Likewise, if you prepare five states and train them with `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to blend smoothly between the states by varying the application rate.
- Specify `network_multiplier` in each `[[datasets]]` section of the `.toml` file (see the example below).

- (Experimental) An option has been added to the training scripts for LoRA etc. to hold the base model weights (the U-Net, and also the Text Encoder when its modules are trained) in fp8 during training. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf.
- Specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
- PyTorch 2.1 or later is required.
- If you use xformers with PyTorch 2.1, see the [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version for your CUDA version.
- Sample image generation during training consumes a lot of memory, so it is recommended to turn it off.
- (Experimental) In training for LoRA etc., a different network multiplier can now be specified for each dataset.
- This is an experimental option and may be removed or changed in the future.
- For example, training state A with `1.0` and state B with `-1.0` may let you switch between states A and B by adjusting the LoRA application rate.
- Likewise, preparing five states and training them with `0.2`, `0.4`, `0.6`, `0.8`, and `1.0` may let you blend smoothly between the states by varying the application rate.
- Specify `network_multiplier` in `[[datasets]]` in the `.toml` file.

- `.toml` example for network multiplier

```toml
[general]
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = 1.0

... subset settings ...

[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = -1.0

... subset settings ...
```
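
For intuition about what the multiplier does: a LoRA network adds a low-rank delta on top of the frozen base weights, and the multiplier simply scales that delta, so a dataset trained with `-1.0` pushes the adapter away from that state rather than toward it. The snippet below is a minimal conceptual sketch of that mechanism, not the actual `networks/lora.py` implementation; the class and attribute names are illustrative only.

```python
import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Illustrative LoRA-wrapped linear layer with a runtime multiplier (not sd-scripts code)."""

    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = base  # frozen base layer
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # the delta starts at zero
        self.scale = alpha / rank
        self.multiplier = 1.0  # what network_multiplier / set_multiplier control

    def set_multiplier(self, multiplier: float) -> None:
        self.multiplier = multiplier

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The multiplier scales only the LoRA delta, never the base output.
        return self.base(x) + self.multiplier * self.scale * self.lora_up(self.lora_down(x))


layer = ToyLoRALinear(nn.Linear(8, 8))
layer.set_multiplier(-1.0)  # e.g. a batch coming from the dataset trained as "state B"
print(layer(torch.randn(1, 8)).shape)
```

At inference, the same scaling is what the LoRA application rate controls, which is why intermediate rates may interpolate between the trained states.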


### Jan 17, 2024 / 2024/1/17: v0.8.1

3 changes: 3 additions & 0 deletions library/config_util.py
@@ -93,6 +93,7 @@ class BaseDatasetParams:
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
max_token_length: int = None
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False


@@ -219,6 +220,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"max_bucket_reso": int,
"min_bucket_reso": int,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
}

# options handled by argparse but not handled by user config
@@ -469,6 +471,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)

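
The pattern in `config_util.py` is a dataclass default plus a per-key cast applied to values read from the user's `.toml`, so an integer like `network_multiplier = 1` is coerced to a float and a missing key falls back to `1.0`. A rough standalone reduction of that idea follows; the `Toy*` names are made up for illustration and are not the actual `config_util.py` API.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ToyDatasetParams:
    # mirrors the new BaseDatasetParams field: a missing key falls back to 1.0
    resolution: Optional[Tuple[int, int]] = None
    network_multiplier: float = 1.0


TOY_CASTS = {"network_multiplier": float}  # per-key casts, like the schema dict above


def toy_build_params(dataset_config: dict) -> ToyDatasetParams:
    cast = {k: TOY_CASTS.get(k, lambda v: v)(v) for k, v in dataset_config.items()}
    return ToyDatasetParams(**cast)


print(toy_build_params({"network_multiplier": -1}))  # -1 (int) is cast to -1.0
print(toy_build_params({}))                          # falls back to the default 1.0
```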
45 changes: 27 additions & 18 deletions library/train_util.py
@@ -558,6 +558,7 @@ def __init__(
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
max_token_length: int,
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
) -> None:
super().__init__()
@@ -567,6 +568,7 @@ def __init__(
self.max_token_length = max_token_length
# width/height is used when enable_bucket==False
self.width, self.height = (None, None) if resolution is None else resolution
self.network_multiplier = network_multiplier
self.debug_dataset = debug_dataset

self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
@@ -1106,7 +1108,9 @@ def __getitem__(self, index):
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
loss_weights.append(
self.prior_loss_weight if image_info.is_reg else 1.0
) # in case of fine tuning, is_reg is always False

flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance

@@ -1272,6 +1276,8 @@ def __getitem__(self, index):
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
example["flippeds"] = flippeds

example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))

if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
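
Each item returned by `__getitem__` therefore carries one multiplier entry per sample, and every entry holds the same per-dataset value. Roughly, with hypothetical values only to show the shape:

```python
import torch

# Hypothetical values, only to illustrate what the new key looks like in one dataset item.
captions = ["state A, front view", "state A, side view"]
network_multiplier = 1.0  # the value set in this dataset's [[datasets]] section

network_multipliers = torch.FloatTensor([network_multiplier] * len(captions))
print(network_multipliers)  # tensor([1., 1.]) -> one identical entry per sample
```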
@@ -1346,15 +1352,16 @@ def __init__(
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)

assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

@@ -1520,14 +1527,15 @@ def __init__(
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)

self.batch_size = batch_size

@@ -1724,14 +1732,15 @@ def __init__(
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)

db_subsets = []
for subset in subsets:
@@ -2039,6 +2048,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
print(
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
)
if "network_multipliers" in example:
print(f"network multiplier: {example['network_multipliers'][j]}")

if show_input_ids:
print(f"input ids: {iid}")
@@ -2105,8 +2116,8 @@ def glob_images_pathlib(dir_path, recursive):


class MinimalDataset(BaseDataset):
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)

self.num_train_images = 0 # update in subclass
self.num_reg_images = 0 # update in subclass
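
Because `MinimalDataset.__init__` gained a positional argument, custom dataset classes built on it have to accept and forward `network_multiplier`. A hypothetical subclass might look like the following; the class name is illustrative and the import path is an assumption based on this repository's layout.

```python
from library.train_util import MinimalDataset


class MyCustomDataset(MinimalDataset):
    def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
        # forward the new argument so BaseDataset stores self.network_multiplier
        super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
        self.num_train_images = 0
        self.num_reg_images = 0
```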
Expand Down Expand Up @@ -2850,14 +2861,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
parser.add_argument(
"--dynamo_backend",
type=str,
default="inductor",
"--dynamo_backend",
type=str,
default="inductor",
# available backends:
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
# https://pytorch.org/docs/stable/torch.compiler.html
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)"
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
@@ -2904,9 +2915,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument(
"--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
)
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
parser.add_argument(
"--ddp_timeout",
type=int,
@@ -3889,7 +3898,7 @@ def prepare_accelerator(args: argparse.Namespace):
os.environ["WANDB_DIR"] = logging_dir
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)

# torch.compile option; if NO, torch.compile is not used
dynamo_backend = "NO"
if args.torch_compile:
13 changes: 12 additions & 1 deletion train_network.py
@@ -310,6 +310,7 @@ def train(self, args):
)
if network is None:
return
network_has_multiplier = hasattr(network, "set_multiplier")

if hasattr(network, "prepare_network"):
network.prepare_network(args)
@@ -768,7 +769,17 @@ def remove_model(old_ckpt_name):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]

# get multiplier for each sample
if network_has_multiplier:
multipliers = batch["network_multipliers"]
# if all multipliers are same, use single multiplier
if torch.all(multipliers == multipliers[0]):
multipliers = multipliers[0].item()
else:
raise NotImplementedError("multipliers for each sample is not supported yet")
# print(f"set multiplier: {multipliers}")
network.set_multiplier(multipliers)

with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
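
The per-batch handling above supports only one multiplier value per batch. A standalone restatement of that check, assuming a `network` object that exposes `set_multiplier` (which is exactly what the `hasattr` test above verifies):

```python
import torch


def apply_batch_multiplier(network, multipliers: torch.Tensor) -> None:
    """Collapse a per-sample multiplier tensor to a single scalar and push it to the network.

    Mirrors the logic in train_network.py: mixed multipliers inside one batch are not
    supported yet, so every entry must be identical (which holds when a batch comes
    from a single dataset).
    """
    if torch.all(multipliers == multipliers[0]):
        network.set_multiplier(multipliers[0].item())
    else:
        raise NotImplementedError("multipliers for each sample is not supported yet")
```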