Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/integration_tests/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ def run_single_test(test_flavor: OverrideDefinitions, output_dir: str):
dump_folder_arg = f"--dump_folder {output_dir}/{test_name}"

# Random init encoder for offline testing
random_init_encoder_arg = "--encoder.test_mode --dataloader.encoder.test_mode"
random_init_arg = "--tokenizer.test_mode --encoder.random_init"
clip_encoder_version_arg = (
"--encoder.clip_encoder tests/assets/flux_test_encoders/clip-vit-large-patch14/"
)
t5_encoder_version_arg = (
"--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/"
)
t5_tokenizer_path_arg = "--tokenizer.t5_tokenizer_path tests/assets/tokenizer"
clip_tokenizer_path_arg = "--tokenizer.clip_tokenizer_path tests/assets/tokenizer"
hf_assets_path_arg = "--hf_assets_path tests/assets/tokenizer"

all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
Expand All @@ -78,9 +80,11 @@ def run_single_test(test_flavor: OverrideDefinitions, output_dir: str):
)

cmd += " " + dump_folder_arg
cmd += " " + random_init_encoder_arg
cmd += " " + random_init_arg
cmd += " " + clip_encoder_version_arg
cmd += " " + t5_encoder_version_arg
cmd += " " + t5_tokenizer_path_arg
cmd += " " + clip_tokenizer_path_arg
cmd += " " + hf_assets_path_arg
if override_arg:
cmd += " " + " ".join(override_arg)
Expand Down
15 changes: 12 additions & 3 deletions tests/unit_tests/test_dataset_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,29 @@ def test_load_dataset(self):
str(256),
"--dataloader.dataset",
dataset_name,
"--dataloader.classifier_free_guidance_prob",
"--dataloader.prompt_dropout_prob",
"0.447",
"--dataloader.encoder.test_mode",
"--encoder.test_mode",
"--tokenizer.test_mode",
"--tokenizer.t5_tokenizer_path",
"tests/assets/tokenizer",
"--tokenizer.clip_tokenizer_path",
"tests/assets/tokenizer",
"--encoder.random_init",
"--encoder.t5_encoder",
"tests/assets/flux_test_encoders/t5-v1_1-xxl",
"--encoder.clip_encoder",
"tests/assets/flux_test_encoders/clip-vit-large-patch14",
]
)

# Build the tokenizer container from config
tokenizer = config.tokenizer.build(tokenizer_path=config.hf_assets_path)

dl = config.dataloader.build(
dp_world_size=world_size,
dp_rank=rank,
local_batch_size=batch_size,
tokenizer=tokenizer,
)

it = iter(dl)
Expand Down Expand Up @@ -107,6 +115,7 @@ def test_load_dataset(self):
dp_world_size=world_size,
dp_rank=rank,
local_batch_size=batch_size,
tokenizer=tokenizer,
)
dl_resumed.load_state_dict(state)
it_resumed = iter(dl_resumed)
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/models/flux/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ This step will download the autoencoder model from HuggingFace and save it to th

Run the following command to train the debug model on a single GPU:
```bash
MODULE=flux CONFIG=flux_debugmodel .run_train.sh
MODULE=flux CONFIG=flux_debugmodel ./run_train.sh
```

If you want to train with other configs, run the following command:
Expand Down
94 changes: 47 additions & 47 deletions torchtitan/models/flux/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,29 @@
ParallelismConfig,
TrainingConfig,
)
from torchtitan.models.flux.configs import Encoder, Inference, Validation
from torchtitan.models.flux.configs import FluxEncoderConfig, Inference
from torchtitan.models.flux.flux_datasets import FluxDataLoader
from torchtitan.models.flux.tokenizer import FluxTokenizerContainer
from torchtitan.models.flux.trainer import FluxTrainer
from torchtitan.models.flux.validate import FluxValidator

from . import model_registry


def flux_debugmodel() -> FluxTrainer.Config:
encoder = Encoder(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
max_t5_encoding_len=256,
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
)
hf_assets_path = "tests/assets/tokenizer"
return FluxTrainer.Config(
hf_assets_path=hf_assets_path,
tokenizer=FluxTokenizerContainer.Config(
t5_tokenizer_path="google/t5-v1_1-xxl",
clip_tokenizer_path="openai/clip-vit-large-patch14",
max_t5_encoding_len=256,
),
encoder=FluxEncoderConfig(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
),
metrics=MetricsProcessor.Config(log_freq=1),
model_spec=model_registry("flux-debug"),
optimizer=OptimizersContainer.Config(lr=8e-4),
Expand All @@ -44,29 +49,24 @@ def flux_debugmodel() -> FluxTrainer.Config:
steps=10,
),
dataloader=FluxDataLoader.Config(
classifier_free_guidance_prob=0.447,
prompt_dropout_prob=0.447,
img_size=256,
encoder=encoder,
hf_assets_path=hf_assets_path,
),
encoder=encoder,
parallelism=ParallelismConfig(context_parallel_degree=1),
activation_checkpoint=ActivationCheckpointConfig(mode="full"),
checkpoint=CheckpointManager.Config(
interval=10,
last_save_model_only=False,
),
validation=Validation(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=4,
),
validator=FluxValidator.Config(
freq=5,
steps=48,
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=4,
dataloader=FluxDataLoader.Config(
dataset="coco-validation",
classifier_free_guidance_prob=0.447,
prompt_dropout_prob=0.0,
img_size=256,
generate_timesteps=True,
),
Expand All @@ -83,13 +83,17 @@ def flux_debugmodel() -> FluxTrainer.Config:


def flux_dev() -> FluxTrainer.Config:
encoder = Encoder(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
max_t5_encoding_len=512,
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
)
return FluxTrainer.Config(
tokenizer=FluxTokenizerContainer.Config(
t5_tokenizer_path="google/t5-v1_1-xxl",
clip_tokenizer_path="openai/clip-vit-large-patch14",
max_t5_encoding_len=512,
),
encoder=FluxEncoderConfig(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
),
metrics=MetricsProcessor.Config(log_freq=100),
model_spec=model_registry("flux-dev"),
optimizer=OptimizersContainer.Config(lr=1e-4),
Expand All @@ -103,24 +107,20 @@ def flux_dev() -> FluxTrainer.Config:
),
dataloader=FluxDataLoader.Config(
dataset="cc12m-wds",
classifier_free_guidance_prob=0.447,
prompt_dropout_prob=0.447,
img_size=256,
encoder=encoder,
),
encoder=encoder,
activation_checkpoint=ActivationCheckpointConfig(mode="full"),
checkpoint=CheckpointManager.Config(interval=1000),
validation=Validation(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
),
validator=FluxValidator.Config(
freq=1000,
steps=12,
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
dataloader=FluxDataLoader.Config(
dataset="coco-validation",
classifier_free_guidance_prob=0,
prompt_dropout_prob=0,
img_size=256,
generate_timesteps=True,
),
Expand All @@ -132,13 +132,17 @@ def flux_dev() -> FluxTrainer.Config:


def flux_schnell() -> FluxTrainer.Config:
encoder = Encoder(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
max_t5_encoding_len=256,
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
)
return FluxTrainer.Config(
tokenizer=FluxTokenizerContainer.Config(
t5_tokenizer_path="google/t5-v1_1-xxl",
clip_tokenizer_path="openai/clip-vit-large-patch14",
max_t5_encoding_len=256,
),
encoder=FluxEncoderConfig(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
),
metrics=MetricsProcessor.Config(log_freq=100),
model_spec=model_registry("flux-schnell"),
optimizer=OptimizersContainer.Config(lr=1e-4),
Expand All @@ -152,24 +156,20 @@ def flux_schnell() -> FluxTrainer.Config:
),
dataloader=FluxDataLoader.Config(
dataset="cc12m-wds",
classifier_free_guidance_prob=0.447,
prompt_dropout_prob=0.447,
img_size=256,
encoder=encoder,
),
encoder=encoder,
activation_checkpoint=ActivationCheckpointConfig(mode="full"),
checkpoint=CheckpointManager.Config(interval=1000),
validation=Validation(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
),
validator=FluxValidator.Config(
freq=1000,
steps=6,
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
dataloader=FluxDataLoader.Config(
dataset="coco-validation",
classifier_free_guidance_prob=0,
prompt_dropout_prob=0,
img_size=256,
generate_timesteps=True,
),
Expand Down
36 changes: 13 additions & 23 deletions torchtitan/models/flux/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,19 @@


@dataclass(kw_only=True, slots=True)
class Encoder:
class FluxEncoderConfig:
"""Configuration for Flux encoders (T5 text encoder, CLIP text encoder, and autoencoder)."""

t5_encoder: str = "google/t5-v1_1-small"
"""T5 encoder to use, HuggingFace model name. This field could be either a local folder path,
or a Huggingface repo name."""
"""HuggingFace model name or local path for the T5 text encoder."""
clip_encoder: str = "openai/clip-vit-large-patch14"
"""Clip encoder to use, HuggingFace model name. This field could be either a local folder path,
or a Huggingface repo name."""
"""HuggingFace model name or local path for the CLIP text encoder."""
autoencoder_path: str = (
"torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"
)
"""Autoencoder checkpoint path to load. This should be a local path referring to a safetensors file."""
max_t5_encoding_len: int = 256
"""Maximum length of the T5 encoding."""

test_mode: bool = False
"""Whether to use integration test mode, which will randomly initialize the encoder and use a dummy tokenizer"""


# TODO: maybe consolidate with FluxValidator.Config
@dataclass(kw_only=True, slots=True)
class Validation:
enable_classifier_free_guidance: bool = False
"""Whether to use classifier-free guidance during sampling"""
classifier_free_guidance_scale: float = 5.0
"""Classifier-free guidance scale when sampling"""
denoising_steps: int = 50
"""How many denoising steps to sample when generating an image"""
eval_freq: int = 100
"""Frequency of evaluation/sampling during training"""
random_init: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
random_init: bool = False
_random_init: bool = False

since it's for test only

"""If True, initialize encoders with random weights instead of loading pretrained weights (for testing only)."""


@dataclass(kw_only=True, slots=True)
Expand All @@ -51,3 +35,9 @@ class Inference:
"""Batch size for inference"""
img_size: int = 256
"""Image size for inference"""
enable_classifier_free_guidance: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we are putting these configs in both validator.config and inference.config?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to a shared config as Claude suggested; these duplicated fields are used for generating / sampling an image in both validation and inference.

"""Whether to use classifier-free guidance during sampling"""
classifier_free_guidance_scale: float = 5.0
"""Classifier-free guidance scale when sampling"""
denoising_steps: int = 50
"""How many denoising steps to sample when generating an image"""
Loading
Loading