Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add flux tokenizer container
  • Loading branch information
wwwjn committed Mar 10, 2026
commit 317d630f324542c28423c5795b3dece362d99685
19 changes: 19 additions & 0 deletions torchtitan/models/flux/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from torchtitan.models.flux.configs import Encoder, Inference, Validation
from torchtitan.models.flux.flux_datasets import FluxDataLoader
from torchtitan.models.flux.tokenizer import FluxTokenizerContainer
from torchtitan.models.flux.trainer import FluxTrainer
from torchtitan.models.flux.validate import FluxValidator

Expand All @@ -31,6 +32,12 @@ def flux_debugmodel() -> FluxTrainer.Config:
hf_assets_path = "tests/assets/tokenizer"
return FluxTrainer.Config(
hf_assets_path=hf_assets_path,
tokenizer=FluxTokenizerContainer.Config(
t5_encoder=encoder.t5_encoder,
clip_encoder=encoder.clip_encoder,
max_t5_encoding_len=encoder.max_t5_encoding_len,
test_mode=encoder.test_mode,
),
metrics=MetricsProcessor.Config(log_freq=1),
model_spec=model_registry("flux-debug"),
optimizer=OptimizersContainer.Config(lr=8e-4),
Expand Down Expand Up @@ -90,6 +97,12 @@ def flux_dev() -> FluxTrainer.Config:
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
)
return FluxTrainer.Config(
tokenizer=FluxTokenizerContainer.Config(
t5_encoder=encoder.t5_encoder,
clip_encoder=encoder.clip_encoder,
max_t5_encoding_len=encoder.max_t5_encoding_len,
test_mode=encoder.test_mode,
),
metrics=MetricsProcessor.Config(log_freq=100),
model_spec=model_registry("flux-dev"),
optimizer=OptimizersContainer.Config(lr=1e-4),
Expand Down Expand Up @@ -139,6 +152,12 @@ def flux_schnell() -> FluxTrainer.Config:
autoencoder_path="assets/hf/FLUX.1-dev/ae.safetensors",
)
return FluxTrainer.Config(
tokenizer=FluxTokenizerContainer.Config(
t5_encoder=encoder.t5_encoder,
clip_encoder=encoder.clip_encoder,
max_t5_encoding_len=encoder.max_t5_encoding_len,
test_mode=encoder.test_mode,
),
metrics=MetricsProcessor.Config(log_freq=100),
model_spec=model_registry("flux-schnell"),
optimizer=OptimizersContainer.Config(lr=1e-4),
Expand Down
21 changes: 16 additions & 5 deletions torchtitan/models/flux/flux_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.hf_datasets import DatasetConfig
from torchtitan.models.flux.tokenizer import build_flux_tokenizer, FluxTokenizer
from torchtitan.models.flux.tokenizer import (
build_flux_tokenizer,
FluxTokenizer,
FluxTokenizerContainer,
)
from torchtitan.tools.logging import logger

from .configs import Encoder
Expand Down Expand Up @@ -407,13 +411,20 @@ def __init__(
dp_world_size: int,
dp_rank: int,
local_batch_size: int,
tokenizer: BaseTokenizer | None = None,
**kwargs,
):

t5_tokenizer, clip_tokenizer = build_flux_tokenizer(
encoder_config=config.encoder,
hf_assets_path=config.hf_assets_path,
)
# Use tokenizer from trainer if provided (FluxTokenizerContainer)
if tokenizer is not None and isinstance(tokenizer, FluxTokenizerContainer):
t5_tokenizer = tokenizer.t5_tokenizer
clip_tokenizer = tokenizer.clip_tokenizer
else:
# Fallback to existing build logic for backward compatibility
t5_tokenizer, clip_tokenizer = build_flux_tokenizer(
encoder_config=config.encoder,
hf_assets_path=config.hf_assets_path,
)

if config.generate_timesteps:
ds = FluxValidationDataset(
Expand Down
49 changes: 49 additions & 0 deletions torchtitan/models/flux/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.


from dataclasses import dataclass

import torch
from transformers import CLIPTokenizer, T5Tokenizer

Expand All @@ -16,6 +18,53 @@
from .configs import Encoder


class FluxTokenizerContainer(BaseTokenizer):
    """Container holding both T5 and CLIP tokenizers for Flux.

    This plugs into Trainer.Config.tokenizer so that tokenizers are built
    by the trainer (via Configurable.Config.build) rather than inside the
    dataloader.

    Attributes:
        t5_tokenizer: Tokenizer for the T5 text encoder.
        clip_tokenizer: Tokenizer for the CLIP text encoder.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(BaseTokenizer.Config):
        # HF model id (or local path) for the T5 tokenizer.
        t5_encoder: str = "google/t5-v1_1-small"
        # HF model id (or local path) for the CLIP tokenizer.
        clip_encoder: str = "openai/clip-vit-large-patch14"
        # Maximum token length for T5 encodings.
        max_t5_encoding_len: int = 256
        # When True, build lightweight FluxTestTokenizer instances from
        # tokenizer_path instead of downloading real HF tokenizers.
        test_mode: bool = False

    def __init__(self, config: Config, *, tokenizer_path: str = "", **kwargs):
        """Build both tokenizers from ``config``.

        Args:
            config: Tokenizer configuration (see ``Config``).
            tokenizer_path: Maps to hf_assets_path; only consulted in test
                mode, where it points both tokenizers at local test assets.
            **kwargs: Ignored; accepted for build-signature compatibility.
        """
        super().__init__()
        if config.test_mode:
            tokenizer_class = FluxTestTokenizer
            t5_path = clip_path = tokenizer_path
        else:
            tokenizer_class = FluxTokenizer
            t5_path = config.t5_encoder
            clip_path = config.clip_encoder

        self.t5_tokenizer: BaseTokenizer = tokenizer_class(
            t5_path, max_length=config.max_t5_encoding_len
        )
        # CLIP's text encoder has a fixed 77-token context window.
        self.clip_tokenizer: BaseTokenizer = tokenizer_class(
            clip_path, max_length=77
        )

    def encode(self, *args, **kwargs) -> list[int]:
        """Not supported on the container; pick a concrete tokenizer.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Use t5_tokenizer.encode() or clip_tokenizer.encode() directly"
        )

    def decode(self, *args, **kwargs) -> str:
        """Not supported on the container; pick a concrete tokenizer.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Use t5_tokenizer.decode() or clip_tokenizer.decode() directly"
        )

    def get_vocab_size(self) -> int:
        """Return the T5 tokenizer's vocab size (used for model sizing)."""
        return self.t5_tokenizer.get_vocab_size()

class FluxTestTokenizer(BaseTokenizer):
"""
Flux Tokenizer for test purpose. This is a simple wrapper around the TikTokenizer,
Expand Down
11 changes: 3 additions & 8 deletions torchtitan/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ class Config(Configurable.Config):
metrics: MetricsProcessor.Config = field(
default_factory=MetricsProcessor.Config
)
# TODO: remove the optional flag once Flux tokenizer is modeled properly
tokenizer: BaseTokenizer.Config | None = field(
tokenizer: BaseTokenizer.Config = field(
default_factory=HuggingFaceTokenizer.Config
)
dataloader: BaseDataLoader.Config = field(default_factory=BaseDataLoader.Config)
Expand Down Expand Up @@ -163,7 +162,7 @@ def maybe_log(self) -> None:
parallel_dims: ParallelDims

# swappable training components
tokenizer: BaseTokenizer | None
tokenizer: BaseTokenizer
dataloader: BaseDataLoader
model_config: BaseModel.Config
# TODO: we should make this list[BaseModel / Decoder] but this will affect many components.
Expand Down Expand Up @@ -232,11 +231,7 @@ def __init__(self, config: Config):
)

# build tokenizer
self.tokenizer = (
config.tokenizer.build(tokenizer_path=config.hf_assets_path)
if config.tokenizer is not None
else None
)
self.tokenizer = config.tokenizer.build(tokenizer_path=config.hf_assets_path)

# build dataloader
self.dataloader = config.dataloader.build(
Expand Down