Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/integration_tests/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ def run_single_test(test_flavor: OverrideDefinitions, output_dir: str):
dump_folder_arg = f"--dump_folder {output_dir}/{test_name}"

# Random init encoder for offline testing
random_init_encoder_arg = "--encoder.test_mode --dataloader.encoder.test_mode"
random_init_arg = "--tokenizer.test_mode --encoder._random_init"
clip_encoder_version_arg = (
"--encoder.clip_encoder tests/assets/flux_test_encoders/clip-vit-large-patch14/"
)
t5_encoder_version_arg = (
"--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/"
)
t5_tokenizer_path_arg = "--tokenizer.t5_tokenizer_path tests/assets/tokenizer"
clip_tokenizer_path_arg = "--tokenizer.clip_tokenizer_path tests/assets/tokenizer"
hf_assets_path_arg = "--hf_assets_path tests/assets/tokenizer"

all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
Expand All @@ -78,9 +80,11 @@ def run_single_test(test_flavor: OverrideDefinitions, output_dir: str):
)

cmd += " " + dump_folder_arg
cmd += " " + random_init_encoder_arg
cmd += " " + random_init_arg
cmd += " " + clip_encoder_version_arg
cmd += " " + t5_encoder_version_arg
cmd += " " + t5_tokenizer_path_arg
cmd += " " + clip_tokenizer_path_arg
cmd += " " + hf_assets_path_arg
if override_arg:
cmd += " " + " ".join(override_arg)
Expand Down
15 changes: 12 additions & 3 deletions tests/unit_tests/test_dataset_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,29 @@ def test_load_dataset(self):
str(256),
"--dataloader.dataset",
dataset_name,
"--dataloader.classifier_free_guidance_prob",
"--dataloader.prompt_dropout_prob",
"0.447",
"--dataloader.encoder.test_mode",
"--encoder.test_mode",
"--tokenizer.test_mode",
"--tokenizer.t5_tokenizer_path",
"tests/assets/tokenizer",
"--tokenizer.clip_tokenizer_path",
"tests/assets/tokenizer",
"--encoder._random_init",
"--encoder.t5_encoder",
"tests/assets/flux_test_encoders/t5-v1_1-xxl",
"--encoder.clip_encoder",
"tests/assets/flux_test_encoders/clip-vit-large-patch14",
]
)

# Build the tokenizer container from config
tokenizer = config.tokenizer.build(tokenizer_path=config.hf_assets_path)

dl = config.dataloader.build(
dp_world_size=world_size,
dp_rank=rank,
local_batch_size=batch_size,
tokenizer=tokenizer,
)

it = iter(dl)
Expand Down Expand Up @@ -107,6 +115,7 @@ def test_load_dataset(self):
dp_world_size=world_size,
dp_rank=rank,
local_batch_size=batch_size,
tokenizer=tokenizer,
)
dl_resumed.load_state_dict(state)
it_resumed = iter(dl_resumed)
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/models/flux/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ This step will download the autoencoder model from HuggingFace and save it to th

Run the following command to train the debug model on a single GPU:
```bash
MODULE=flux CONFIG=flux_debugmodel .run_train.sh
MODULE=flux CONFIG=flux_debugmodel ./run_train.sh
```

If you want to train with other configs, run the following command:
Expand Down
100 changes: 53 additions & 47 deletions torchtitan/models/flux/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,29 @@
ParallelismConfig,
TrainingConfig,
)
from torchtitan.models.flux.configs import Encoder, Inference, Validation
from torchtitan.models.flux.configs import FluxEncoderConfig, Inference, SamplingConfig
from torchtitan.models.flux.flux_datasets import FluxDataLoader
from torchtitan.models.flux.tokenizer import FluxTokenizerContainer
from torchtitan.models.flux.trainer import FluxTrainer
from torchtitan.models.flux.validate import FluxValidator

from . import model_registry


def flux_debugmodel() -> FluxTrainer.Config:
encoder = Encoder(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
max_t5_encoding_len=256,
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
)
hf_assets_path = "tests/assets/tokenizer"
return FluxTrainer.Config(
hf_assets_path=hf_assets_path,
tokenizer=FluxTokenizerContainer.Config(
t5_tokenizer_path="google/t5-v1_1-xxl",
clip_tokenizer_path="openai/clip-vit-large-patch14",
max_t5_encoding_len=256,
),
encoder=FluxEncoderConfig(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
),
metrics=MetricsProcessor.Config(log_freq=1),
model_spec=model_registry("flux-debug"),
optimizer=OptimizersContainer.Config(lr=8e-4),
Expand All @@ -44,29 +49,26 @@ def flux_debugmodel() -> FluxTrainer.Config:
steps=10,
),
dataloader=FluxDataLoader.Config(
classifier_free_guidance_prob=0.447,
prompt_dropout_prob=0.447,
img_size=256,
encoder=encoder,
hf_assets_path=hf_assets_path,
),
encoder=encoder,
parallelism=ParallelismConfig(context_parallel_degree=1),
activation_checkpoint=ActivationCheckpointConfig(mode="full"),
checkpoint=CheckpointManager.Config(
interval=10,
last_save_model_only=False,
),
validation=Validation(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=4,
),
validator=FluxValidator.Config(
freq=5,
steps=48,
sampling=SamplingConfig(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=4,
),
dataloader=FluxDataLoader.Config(
dataset="coco-validation",
classifier_free_guidance_prob=0.447,
prompt_dropout_prob=0.0,
img_size=256,
generate_timesteps=True,
),
Expand All @@ -83,13 +85,17 @@ def flux_debugmodel() -> FluxTrainer.Config:


def flux_dev() -> FluxTrainer.Config:
encoder = Encoder(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
max_t5_encoding_len=512,
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
)
return FluxTrainer.Config(
tokenizer=FluxTokenizerContainer.Config(
t5_tokenizer_path="google/t5-v1_1-xxl",
clip_tokenizer_path="openai/clip-vit-large-patch14",
max_t5_encoding_len=512,
),
encoder=FluxEncoderConfig(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
),
metrics=MetricsProcessor.Config(log_freq=100),
model_spec=model_registry("flux-dev"),
optimizer=OptimizersContainer.Config(lr=1e-4),
Expand All @@ -103,24 +109,22 @@ def flux_dev() -> FluxTrainer.Config:
),
dataloader=FluxDataLoader.Config(
dataset="cc12m-wds",
classifier_free_guidance_prob=0.447,
prompt_dropout_prob=0.447,
img_size=256,
encoder=encoder,
),
encoder=encoder,
activation_checkpoint=ActivationCheckpointConfig(mode="full"),
checkpoint=CheckpointManager.Config(interval=1000),
validation=Validation(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
),
validator=FluxValidator.Config(
freq=1000,
steps=12,
sampling=SamplingConfig(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
),
dataloader=FluxDataLoader.Config(
dataset="coco-validation",
classifier_free_guidance_prob=0,
prompt_dropout_prob=0,
img_size=256,
generate_timesteps=True,
),
Expand All @@ -132,13 +136,17 @@ def flux_dev() -> FluxTrainer.Config:


def flux_schnell() -> FluxTrainer.Config:
encoder = Encoder(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
max_t5_encoding_len=256,
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
)
return FluxTrainer.Config(
tokenizer=FluxTokenizerContainer.Config(
t5_tokenizer_path="google/t5-v1_1-xxl",
clip_tokenizer_path="openai/clip-vit-large-patch14",
max_t5_encoding_len=256,
),
encoder=FluxEncoderConfig(
t5_encoder="google/t5-v1_1-xxl",
clip_encoder="openai/clip-vit-large-patch14",
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
),
metrics=MetricsProcessor.Config(log_freq=100),
model_spec=model_registry("flux-schnell"),
optimizer=OptimizersContainer.Config(lr=1e-4),
Expand All @@ -152,24 +160,22 @@ def flux_schnell() -> FluxTrainer.Config:
),
dataloader=FluxDataLoader.Config(
dataset="cc12m-wds",
classifier_free_guidance_prob=0.447,
prompt_dropout_prob=0.447,
img_size=256,
encoder=encoder,
),
encoder=encoder,
activation_checkpoint=ActivationCheckpointConfig(mode="full"),
checkpoint=CheckpointManager.Config(interval=1000),
validation=Validation(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
),
validator=FluxValidator.Config(
freq=1000,
steps=6,
sampling=SamplingConfig(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
),
dataloader=FluxDataLoader.Config(
dataset="coco-validation",
classifier_free_guidance_prob=0,
prompt_dropout_prob=0,
img_size=256,
generate_timesteps=True,
),
Expand Down
49 changes: 31 additions & 18 deletions torchtitan/models/flux/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,50 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from dataclasses import dataclass, field


@dataclass(kw_only=True, slots=True)
class Encoder:
class FluxEncoderConfig:
"""Configuration for Flux encoders (T5 text encoder, CLIP text encoder, and autoencoder)."""

t5_encoder: str = "google/t5-v1_1-small"
"""T5 encoder to use, HuggingFace model name. This field could be either a local folder path,
or a Huggingface repo name."""
"""HuggingFace model name or local path for the T5 text encoder."""
clip_encoder: str = "openai/clip-vit-large-patch14"
"""Clip encoder to use, HuggingFace model name. This field could be either a local folder path,
or a Huggingface repo name."""
"""HuggingFace model name or local path for the CLIP text encoder."""
autoencoder_path: str = (
"torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"
)
"""Autoencoder checkpoint path to load. This should be a local path referring to a safetensors file."""
max_t5_encoding_len: int = 256
"""Maximum length of the T5 encoding."""

test_mode: bool = False
"""Whether to use integration test mode, which will randomly initialize the encoder and use a dummy tokenizer"""
_random_init: bool = False
"""If True, initialize encoders with random weights instead of loading pretrained weights (for testing only)."""


# TODO: maybe consolidate with FluxValidator.Config
@dataclass(kw_only=True, slots=True)
class Validation:
class SamplingConfig:
"""Shared configuration for image generation sampling (used by both validation and inference)."""

enable_classifier_free_guidance: bool = False
"""Whether to use classifier-free guidance during sampling"""
"""Whether to use classifier-free guidance (CFG) during image generation.

When enabled, the model runs two forward passes per denoising step — one with
the text prompt and one without — then interpolates the results using
`classifier_free_guidance_scale` to produce images that more closely follow
the prompt. This typically yields higher-quality, more prompt-adherent images
but doubles the compute cost per sampling step.
"""

classifier_free_guidance_scale: float = 5.0
"""Classifier-free guidance scale when sampling"""
"""Interpolation weight for classifier-free guidance during sampling.

Higher values steer the output more strongly toward the text prompt, producing
sharper and more prompt-adherent images, but may reduce diversity or introduce
artifacts. Typical values range from 1.0 (no guidance) to 10.0 (strong guidance).
Only takes effect when `enable_classifier_free_guidance` is True.
"""

denoising_steps: int = 50
"""How many denoising steps to sample when generating an image"""
eval_freq: int = 100
"""Frequency of evaluation/sampling during training"""
"""How many denoising steps to sample when generating an image."""


@dataclass(kw_only=True, slots=True)
Expand All @@ -51,3 +62,5 @@ class Inference:
"""Batch size for inference"""
img_size: int = 256
"""Image size for inference"""
sampling: SamplingConfig = field(default_factory=SamplingConfig)
"""Sampling configuration for image generation"""
Loading
Loading