Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor encode/decode
  • Loading branch information
wwwjn committed Mar 10, 2026
commit 12169d6f127153071323ad4ca0ee0cd44d69bf15
26 changes: 16 additions & 10 deletions torchtitan/models/flux/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ParallelismConfig,
TrainingConfig,
)
from torchtitan.models.flux.configs import FluxEncoderConfig, Inference
from torchtitan.models.flux.configs import FluxEncoderConfig, Inference, SamplingConfig
from torchtitan.models.flux.flux_datasets import FluxDataLoader
from torchtitan.models.flux.tokenizer import FluxTokenizerContainer
from torchtitan.models.flux.trainer import FluxTrainer
Expand Down Expand Up @@ -61,9 +61,11 @@ def flux_debugmodel() -> FluxTrainer.Config:
validator=FluxValidator.Config(
freq=5,
steps=48,
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=4,
sampling=SamplingConfig(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=4,
),
dataloader=FluxDataLoader.Config(
dataset="coco-validation",
prompt_dropout_prob=0.0,
Expand Down Expand Up @@ -115,9 +117,11 @@ def flux_dev() -> FluxTrainer.Config:
validator=FluxValidator.Config(
freq=1000,
steps=12,
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
sampling=SamplingConfig(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
),
dataloader=FluxDataLoader.Config(
dataset="coco-validation",
prompt_dropout_prob=0,
Expand Down Expand Up @@ -164,9 +168,11 @@ def flux_schnell() -> FluxTrainer.Config:
validator=FluxValidator.Config(
freq=1000,
steps=6,
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
sampling=SamplingConfig(
enable_classifier_free_guidance=True,
classifier_free_guidance_scale=5.0,
denoising_steps=50,
),
dataloader=FluxDataLoader.Config(
dataset="coco-validation",
prompt_dropout_prob=0,
Expand Down
39 changes: 31 additions & 8 deletions torchtitan/models/flux/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from dataclasses import dataclass, field


@dataclass(kw_only=True, slots=True)
Expand All @@ -19,10 +19,37 @@ class FluxEncoderConfig:
"torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"
)
"""Autoencoder checkpoint path to load. This should be a local path referring to a safetensors file."""
random_init: bool = False
_random_init: bool = False
"""If True, initialize encoders with random weights instead of loading pretrained weights (for testing only)."""


@dataclass(kw_only=True, slots=True)
class SamplingConfig:
"""Shared configuration for image generation sampling (used by both validation and inference)."""

enable_classifier_free_guidance: bool = False
"""Whether to use classifier-free guidance (CFG) during image generation.

When enabled, the model runs two forward passes per denoising step — one with
the text prompt and one without — then interpolates the results using
`classifier_free_guidance_scale` to produce images that more closely follow
the prompt. This typically yields higher-quality, more prompt-adherent images
but doubles the compute cost per sampling step.
"""

classifier_free_guidance_scale: float = 5.0
"""Interpolation weight for classifier-free guidance during sampling.

Higher values steer the output more strongly toward the text prompt, producing
sharper and more prompt-adherent images, but may reduce diversity or introduce
artifacts. Typical values range from 1.0 (no guidance) to 10.0 (strong guidance).
Only takes effect when `enable_classifier_free_guidance` is True.
"""

denoising_steps: int = 50
"""How many denoising steps to sample when generating an image."""


@dataclass(kw_only=True, slots=True)
class Inference:
"""Inference configuration"""
Expand All @@ -35,9 +62,5 @@ class Inference:
"""Batch size for inference"""
img_size: int = 256
"""Image size for inference"""
enable_classifier_free_guidance: bool = False
"""Whether to use classifier-free guidance during sampling"""
classifier_free_guidance_scale: float = 5.0
"""Classifier-free guidance scale when sampling"""
denoising_steps: int = 50
"""How many denoising steps to sample when generating an image"""
sampling: SamplingConfig = field(default_factory=SamplingConfig)
"""Sampling configuration for image generation"""
65 changes: 24 additions & 41 deletions torchtitan/models/flux/flux_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.hf_datasets import DatasetConfig
from torchtitan.models.flux.tokenizer import FluxTokenizer, FluxTokenizerContainer
from torchtitan.models.flux.tokenizer import FluxTokenizerContainer
from torchtitan.tools.logging import logger


Expand Down Expand Up @@ -77,60 +77,52 @@ def _process_cc12m_image(

def _cc12m_wds_data_processor(
sample: dict[str, Any],
t5_tokenizer: FluxTokenizer,
clip_tokenizer: FluxTokenizer,
tokenizer: FluxTokenizerContainer,
output_size: int = 256,
) -> dict[str, Any]:
"""
Preprocess CC12M dataset sample image and text for Flux model.

Args:
sample: A sample from dataset
t5_tokenizer: T5 tokenizer
clip_tokenizer: CLIP tokenizer
tokenizer: FluxTokenizerContainer that encodes text with both T5 and CLIP
output_size: The output image size

"""
img = _process_cc12m_image(sample["jpg"], output_size=output_size)
t5_tokens = t5_tokenizer.encode(sample["txt"])
clip_tokens = clip_tokenizer.encode(sample["txt"])
tokens = tokenizer.encode(sample["txt"])

return {
"image": img,
"clip_tokens": clip_tokens, # type: List[int]
"t5_tokens": t5_tokens, # type: List[int]
"prompt": sample["txt"], # type: str
**tokens,
"prompt": sample["txt"],
}


def _coco_data_processor(
sample: dict[str, Any],
t5_tokenizer: FluxTokenizer,
clip_tokenizer: FluxTokenizer,
tokenizer: FluxTokenizerContainer,
output_size: int = 256,
) -> dict[str, Any]:
"""
Preprocess COCO dataset sample image and text for Flux model.

Args:
sample: A sample from dataset
t5_tokenizer: T5 tokenizer
clip_tokenizer: CLIP tokenizer
tokenizer: FluxTokenizerContainer that encodes text with both T5 and CLIP
output_size: The output image size

"""
img = _process_cc12m_image(sample["image"], output_size=output_size)
prompt = sample["caption"]
if isinstance(prompt, list):
prompt = prompt[0]
t5_tokens = t5_tokenizer.encode(prompt)
clip_tokens = clip_tokenizer.encode(prompt)
tokens = tokenizer.encode(prompt)

return {
"image": img,
"clip_tokens": clip_tokens, # type: List[int]
"t5_tokens": t5_tokens, # type: List[int]
"prompt": prompt, # type: str
**tokens,
"prompt": prompt,
}


Expand Down Expand Up @@ -187,8 +179,7 @@ def __init__(
self,
dataset_name: str,
dataset_path: str | None,
t5_tokenizer: BaseTokenizer,
clip_tokenizer: BaseTokenizer,
tokenizer: FluxTokenizerContainer,
prompt_dropout_prob: float,
img_size: int,
dp_rank: int = 0,
Expand All @@ -207,10 +198,10 @@ def __init__(
self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)

self._t5_tokenizer = t5_tokenizer
self._t5_empty_token = t5_tokenizer.encode("")
self._clip_tokenizer = clip_tokenizer
self._clip_empty_token = clip_tokenizer.encode("")
self._tokenizer = tokenizer
empty_tokens = tokenizer.encode("")
self._t5_empty_token = empty_tokens["t5_tokens"]
self._clip_empty_token = empty_tokens["clip_tokens"]
self._data_processor = data_processor
self.prompt_dropout_prob = prompt_dropout_prob
self.img_size = img_size
Expand Down Expand Up @@ -263,8 +254,7 @@ def __iter__(self):
# Use the dataset-specific preprocessor
sample_dict = self._data_processor(
sample,
self._t5_tokenizer,
self._clip_tokenizer,
self._tokenizer,
output_size=self.img_size,
)

Expand Down Expand Up @@ -319,8 +309,7 @@ def __init__(
self,
dataset_name: str,
dataset_path: str | None,
t5_tokenizer: BaseTokenizer,
clip_tokenizer: BaseTokenizer,
tokenizer: FluxTokenizerContainer,
prompt_dropout_prob: float,
img_size: int,
dp_rank: int = 0,
Expand All @@ -332,8 +321,7 @@ def __init__(
super().__init__(
dataset_name=dataset_name,
dataset_path=dataset_path,
t5_tokenizer=t5_tokenizer,
clip_tokenizer=clip_tokenizer,
tokenizer=tokenizer,
prompt_dropout_prob=prompt_dropout_prob,
img_size=img_size,
dp_rank=dp_rank,
Expand Down Expand Up @@ -393,11 +381,10 @@ class Config(ParallelAwareDataloader.Config):

def __post_init__(self):
if self.generate_timesteps and self.prompt_dropout_prob != 0.0:
logger.warning(
f"prompt_dropout_prob={self.prompt_dropout_prob} "
"overridden to 0.0 for validation (generate_timesteps=True)."
raise ValueError(
f"prompt_dropout_prob must be 0.0 when generate_timesteps=True "
f"(for validation), but got {self.prompt_dropout_prob}."
)
self.prompt_dropout_prob = 0.0

def __init__(
self,
Expand All @@ -415,15 +402,12 @@ def __init__(
"FluxDataLoader requires a FluxTokenizerContainer as tokenizer. "
"Set tokenizer=FluxTokenizerContainer.Config(...) in your trainer config."
)
t5_tokenizer = tokenizer.t5_tokenizer
clip_tokenizer = tokenizer.clip_tokenizer

if config.generate_timesteps:
ds = FluxValidationDataset(
dataset_name=config.dataset,
dataset_path=config.dataset_path,
t5_tokenizer=t5_tokenizer,
clip_tokenizer=clip_tokenizer,
tokenizer=tokenizer,
prompt_dropout_prob=config.prompt_dropout_prob,
img_size=config.img_size,
dp_rank=dp_rank,
Expand All @@ -435,8 +419,7 @@ def __init__(
ds = FluxDataset(
dataset_name=config.dataset,
dataset_path=config.dataset_path,
t5_tokenizer=t5_tokenizer,
clip_tokenizer=clip_tokenizer,
tokenizer=tokenizer,
prompt_dropout_prob=config.prompt_dropout_prob,
img_size=config.img_size,
dp_rank=dp_rank,
Expand Down
11 changes: 4 additions & 7 deletions torchtitan/models/flux/inference/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ def inference(config: FluxTrainer.Config):

# Build tokenizers from the config
tokenizer_container = config.tokenizer.build()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: no need to call it container -- you actually wouldn't know from this file

t5_tokenizer = tokenizer_container.t5_tokenizer
clip_tokenizer = tokenizer_container.clip_tokenizer

if global_rank == 0:
logger.info("Starting inference...")
Expand All @@ -68,17 +66,16 @@ def inference(config: FluxTrainer.Config):
img_height=16 * (img_size // 16),
img_width=16 * (img_size // 16),
# pyrefly: ignore [missing-attribute]
enable_classifier_free_guidance=config.inference.enable_classifier_free_guidance,
enable_classifier_free_guidance=config.inference.sampling.enable_classifier_free_guidance,
# pyrefly: ignore [missing-attribute]
denoising_steps=config.inference.denoising_steps,
denoising_steps=config.inference.sampling.denoising_steps,
# pyrefly: ignore [missing-attribute]
classifier_free_guidance_scale=config.inference.classifier_free_guidance_scale,
classifier_free_guidance_scale=config.inference.sampling.classifier_free_guidance_scale,
# pyrefly: ignore [bad-argument-type]
model=trainer.model_parts[0],
prompt=prompts[i : i + bs],
autoencoder=trainer.autoencoder,
t5_tokenizer=t5_tokenizer,
clip_tokenizer=clip_tokenizer,
tokenizer=tokenizer_container,
t5_encoder=trainer.t5_encoder,
clip_encoder=trainer.clip_encoder,
)
Expand Down
19 changes: 9 additions & 10 deletions torchtitan/models/flux/inference/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from einops import rearrange
from PIL import ExifTags, Image
from torch import Tensor
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.models.flux.model.autoencoder import AutoEncoder
from torchtitan.models.flux.model.hf_embedder import FluxEmbedder
from torchtitan.models.flux.model.model import FluxModel
from torchtitan.models.flux.tokenizer import FluxTokenizerContainer
from torchtitan.models.flux.utils import (
create_position_encoding_for_latents,
generate_noise_latent,
Expand Down Expand Up @@ -78,8 +78,7 @@ def generate_image(
model: FluxModel,
prompt: str | list[str],
autoencoder: AutoEncoder,
t5_tokenizer: BaseTokenizer,
clip_tokenizer: BaseTokenizer,
tokenizer: FluxTokenizerContainer,
t5_encoder: FluxEmbedder,
clip_encoder: FluxEmbedder,
) -> torch.Tensor:
Expand All @@ -92,9 +91,10 @@ def generate_image(
if isinstance(prompt, str):
prompt = [prompt]

# Tokenize the prompt. Unsqueeze to add a batch dimension.
clip_tokens = clip_tokenizer.encode(prompt)
t5_tokens = t5_tokenizer.encode(prompt)
# Tokenize the prompt using the container's encode method.
tokens = tokenizer.encode(prompt)
clip_tokens = tokens["clip_tokens"]
t5_tokens = tokens["t5_tokens"]
if len(prompt) == 1:
# pyrefly: ignore [missing-attribute]
clip_tokens = clip_tokens.unsqueeze(0)
Expand All @@ -117,12 +117,11 @@ def generate_image(
if enable_classifier_free_guidance:
num_images = len(prompt)

empty_clip_tokens = clip_tokenizer.encode("")
empty_t5_tokens = t5_tokenizer.encode("")
empty_tokens = tokenizer.encode("")
# pyrefly: ignore [missing-attribute]
empty_clip_tokens = empty_clip_tokens.repeat(num_images, 1)
empty_clip_tokens = empty_tokens["clip_tokens"].repeat(num_images, 1)
# pyrefly: ignore [missing-attribute]
empty_t5_tokens = empty_t5_tokens.repeat(num_images, 1)
empty_t5_tokens = empty_tokens["t5_tokens"].repeat(num_images, 1)

empty_batch = preprocess_data(
device=device,
Expand Down
Loading
Loading