
[Flux] Refactor flux config: Add FluxTokenizerContainer, remove Validation class #2533

Open

wwwjn wants to merge 5 commits into main from flux-tokenizer
Conversation

Contributor

@wwwjn wwwjn commented Mar 10, 2026

  • Add FluxTokenizerContainer: Enforce tokenizer field to be not None
  • Separate tokenizer and encoder configs: Split the old Encoder class into FluxTokenizerContainer.Config (tokenizer paths) and FluxEncoderConfig (encoder paths + autoencoder)
  • Renamed classifier_free_guidance_prob to prompt_dropout_prob: Avoids confusion with inference-time CFG settings (enable_classifier_free_guidance, classifier_free_guidance_scale)

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 10, 2026
Contributor Author

wwwjn commented Mar 10, 2026

@claude Review PR

Contributor

@tianyu-l tianyu-l left a comment


maybe show evidence that numerics didn't change?

"""How many denoising steps to sample when generating an image"""
eval_freq: int = 100
"""Frequency of evaluation/sampling during training"""
random_init: bool = False
Contributor

Suggested change
random_init: bool = False
_random_init: bool = False

since it's for test only

Comment on lines +51 to +59
def encode(self, *args, **kwargs) -> list[int]:
    raise NotImplementedError(
        "Use t5_tokenizer.encode() or clip_tokenizer.encode() directly"
    )

def decode(self, *args, **kwargs) -> str:
    raise NotImplementedError(
        "Use t5_tokenizer.decode() or clip_tokenizer.decode() directly"
    )
Contributor

Let's implement encode/decode as wrapper functions over t5 and clip encode/decode to return a tuple / dict of results, so that the dataloader / model never needs to call t5 / clip tokenizers directly.
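A minimal sketch of what that wrapper API could look like, assuming a dict-of-results shape. The stub tokenizer and the `"t5"`/`"clip"` key names are illustrative only, not the real torchtitan classes:

```python
class _StubTokenizer:
    """Toy tokenizer standing in for the real T5 / CLIP tokenizers."""

    def __init__(self, offset: int):
        self.offset = offset

    def encode(self, text: str) -> list[int]:
        return [ord(c) + self.offset for c in text]

    def decode(self, tokens: list[int]) -> str:
        return "".join(chr(t - self.offset) for t in tokens)


class FluxTokenizerContainer:
    """Wraps both tokenizers so callers never touch t5/clip directly."""

    def __init__(self, t5_tokenizer, clip_tokenizer):
        self.t5_tokenizer = t5_tokenizer
        self.clip_tokenizer = clip_tokenizer

    def encode(self, text: str) -> dict[str, list[int]]:
        # Fan out to every text encoder; one dict entry per encoder.
        return {
            "t5": self.t5_tokenizer.encode(text),
            "clip": self.clip_tokenizer.encode(text),
        }

    def decode(self, tokens: dict[str, list[int]]) -> dict[str, str]:
        # Decode whichever encoders are present in the input dict.
        out = {}
        if "t5" in tokens:
            out["t5"] = self.t5_tokenizer.decode(tokens["t5"])
        if "clip" in tokens:
            out["clip"] = self.clip_tokenizer.decode(tokens["clip"])
        return out


container = FluxTokenizerContainer(_StubTokenizer(0), _StubTokenizer(1))
tokens = container.encode("hi")
```

With this shape the dataloader calls one `encode()` and gets a dict keyed per encoder, so it never needs to know how many tokenizers exist.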

"""Similar to above, this is a hack to get the test tokenizer asset paths."""
def __post_init__(self):
    if self.generate_timesteps and self.prompt_dropout_prob != 0.0:
        logger.warning(
Contributor

let's error out instead of printing a warning

"""Batch size for inference"""
img_size: int = 256
"""Image size for inference"""
enable_classifier_free_guidance: bool = False
Contributor

so we are putting these configs in both validator.config and inference.config?

Contributor Author

Moved these to a shared config as Claude suggested; the duplicated fields are used for generating / sampling an image in both validation and inference.

"""Classifier-free guidance with probability `p` to dropout each text encoding independently.
If `n` text encoders are used, the unconditional model is trained in `p ^ n` of all steps.
For example, if `n = 2` and `p = 0.447`, the unconditional model is trained in 20% of all steps"""
prompt_dropout_prob: float = 0.0
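A quick sanity check of the docstring's arithmetic:

```python
# With n = 2 text encoders, each dropped independently with probability
# p = 0.447, the fully unconditional model is trained in p ** n of all steps.
p, n = 0.447, 2
unconditional_fraction = p**n
print(round(unconditional_fraction, 3))  # prints 0.2, i.e. ~20% of steps
```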
Contributor

How legit is this name change -- what does FLUX repo use?

Contributor Author

The FLUX repo doesn't have a dataloader function; they only open-sourced the inference / sampling code, so there's no direct corresponding name for this field.

Contributor Author

wwwjn commented Mar 10, 2026

@claude review


claude bot commented Mar 10, 2026

Claude finished @wwwjn's task in 2m 34s


PR Review: [Flux] Refactor flux config

  • Read repo guidelines (CLAUDE.md)
  • Read PR diff and changed files
  • Analyze code changes
  • Post review feedback

Summary

This PR separates tokenizer config from encoder config, introduces FluxTokenizerContainer as a proper BaseTokenizer subclass, removes the old Validation dataclass (folding its fields into FluxValidator.Config and Inference), and renames classifier_free_guidance_prob to prompt_dropout_prob. The overall direction is good — it eliminates the hacky Encoder config being used for tokenizer paths and removes the "TODO: remove the optional flag once Flux tokenizer is modeled properly" comment from the base Trainer.


Issues

1. Bug: tokenizer_path kwarg is silently ignored by FluxTokenizerContainer

In torchtitan/trainer.py:234, the base Trainer calls:

self.tokenizer = config.tokenizer.build(tokenizer_path=config.hf_assets_path)

The Configurable.Config.build() method at torchtitan/config/configurable.py:108-138 checks whether kwargs are config fields or constructor arguments. Since tokenizer_path is not a field on FluxTokenizerContainer.Config, it gets routed as a constructor kwarg to FluxTokenizerContainer.__init__(config, **kwargs) — where **kwargs silently absorbs it.

This means hf_assets_path is completely ignored for FluxTokenizerContainer. In test mode, the test tokenizer paths (t5_tokenizer_path, clip_tokenizer_path) need to point to the test asset directory. You've handled this in the test configs by setting explicit t5_tokenizer_path/clip_tokenizer_path in the config, but:

  • In infer.py:42, you call config.tokenizer.build() without tokenizer_path at all, which works but only because the kwarg was already being ignored.
  • The hf_assets_path field on FluxTrainer.Config is now dead code for tokenizer purposes.

Consider either: (a) making FluxTokenizerContainer.__init__ explicitly accept and use tokenizer_path as a fallback for test mode, or (b) documenting that hf_assets_path no longer applies to tokenizer paths and updating flux_debugmodel() to not set it (or only set it for the HuggingFaceTokenizer base in non-Flux trainers).

2. Bug: test_mode doesn't override tokenizer paths to use hf_assets_path

In the old code, build_flux_tokenizer had:

if encoder_config.test_mode:
    tokenizer_class = FluxTestTokenizer
    t5_tokenizer_path = clip_tokenzier_path = hf_assets_path

So test_mode would override both the tokenizer class and the paths to point to hf_assets_path. In the new FluxTokenizerContainer.__init__ (tokenizer.py:36-49), test_mode only switches the tokenizer class but doesn't override the paths. The caller must now separately set t5_tokenizer_path and clip_tokenizer_path in the config to the test asset path.

The integration test (tests/integration_tests/flux.py:17-18) does pass these paths explicitly:

--tokenizer.t5_tokenizer_path tests/assets/tokenizer
--tokenizer.clip_tokenizer_path tests/assets/tokenizer

But this is fragile — test_mode=True without correct paths would silently use production tokenizer paths with FluxTestTokenizer, which would likely fail in unexpected ways. Consider adding a validation in FluxTokenizerContainer.__init__ or Config.__post_init__ when test_mode=True.
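One way the suggested validation could look, as a sketch on FluxTokenizerContainer.Config. The field names follow this review; the `"tests/"` path heuristic and the defaults are assumptions:

```python
from dataclasses import dataclass


@dataclass
class Config:
    t5_tokenizer_path: str = "google/t5-v1_1-xxl"
    clip_tokenizer_path: str = "openai/clip-vit-large-patch14"
    test_mode: bool = False

    def __post_init__(self):
        if not self.test_mode:
            return
        # Fail fast if test_mode is on but the paths still point at
        # production assets, instead of failing later in unexpected ways.
        for name, path in [
            ("t5_tokenizer_path", self.t5_tokenizer_path),
            ("clip_tokenizer_path", self.clip_tokenizer_path),
        ]:
            if "tests/" not in path:
                raise ValueError(
                    f"test_mode=True but {name}={path!r} does not look "
                    "like a test asset path"
                )
```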


3. Duplication: Tokenizer and encoder paths are specified independently

In config_registry.py, every config now specifies the same model paths in both tokenizer and encoder:

tokenizer=FluxTokenizerContainer.Config(
    t5_tokenizer_path="google/t5-v1_1-xxl",  # path A
    clip_tokenizer_path="openai/clip-vit-large-patch14",  # path B
),
encoder=FluxEncoderConfig(
    t5_encoder="google/t5-v1_1-xxl",  # same as path A
    clip_encoder="openai/clip-vit-large-patch14",  # same as path B
),

This is a step backward from the old Encoder class where these were unified. A user could easily set different T5 tokenizer vs encoder paths, which would silently produce incorrect results. (@tianyu-l's related comment about configs duplicated between validator and inference is covered in point 4.)

Consider having a single source of truth for model paths. For example, FluxEncoderConfig could include the tokenizer paths and FluxTokenizerContainer.Config could reference or derive from it, or the tokenizer config could simply derive paths from the encoder config at build time.

4. FluxValidator.Config has duplicated CFG fields with Inference

Both FluxValidator.Config (lines 68-91 in validate.py) and Inference (lines 38-43 in configs.py) now have the same three fields:

  • enable_classifier_free_guidance
  • classifier_free_guidance_scale
  • denoising_steps

This is a 1:1 duplication. Consider extracting these into a shared dataclass (e.g., SamplingConfig) that both FluxValidator.Config and Inference can compose.
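A sketch of that extraction; SamplingConfig is the name proposed above, not existing code, and the defaults shown are assumed:

```python
from dataclasses import dataclass, field


@dataclass
class SamplingConfig:
    # The three duplicated fields, extracted once; defaults are assumed.
    enable_classifier_free_guidance: bool = False
    classifier_free_guidance_scale: float = 5.0
    denoising_steps: int = 50


@dataclass
class ValidatorConfig:
    sampling: SamplingConfig = field(default_factory=SamplingConfig)
    eval_freq: int = 100


@dataclass
class InferenceConfig:
    sampling: SamplingConfig = field(default_factory=SamplingConfig)
    img_size: int = 256
```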


5. Agree with @tianyu-l: _random_init naming and __post_init__ warning

FluxEncoderConfig.random_init (configs.py:22) — agree with @tianyu-l's suggestion to prefix with underscore (_random_init) since it's test-only and shouldn't be a user-facing CLI flag.

Similarly for FluxTokenizerContainer.Config.test_mode (tokenizer.py:34) — consider _test_mode.

For the __post_init__ warning in FluxDataLoader.Config (flux_datasets.py:394-400), agree with @tianyu-l that this should raise an error rather than silently overriding:

def __post_init__(self):
    if self.generate_timesteps and self.prompt_dropout_prob != 0.0:
        raise ValueError(
            f"prompt_dropout_prob must be 0.0 for validation "
            f"(generate_timesteps=True), got {self.prompt_dropout_prob}"
        )

6. FluxTokenizerContainer.encode()/decode() raise NotImplementedError

In tokenizer.py:51-58, FluxTokenizerContainer inherits from BaseTokenizer, which has encode/decode as @abstractmethod, and the implementation raises NotImplementedError. This is acceptable for now, but the error message could be more helpful. Agree with @tianyu-l's comment that implementing proper wrapper encode/decode that return combined results would be cleaner and would make the container usable as a real tokenizer rather than just a holder.

7. Minor: prompt_dropout_prob naming

Regarding @tianyu-l's comment about what the FLUX repo uses — the original FLUX repo (black-forest-labs/flux) doesn't have this concept since it's inference-only. The SD3 paper and diffusers library use "dropout" terminology for this (e.g., text_encoder_dropout). The rename from classifier_free_guidance_prob to prompt_dropout_prob is reasonable since it avoids confusion with inference-time CFG scale settings.

Positive aspects

  • Removes the Optional[BaseTokenizer.Config] hack from Trainer.Config — this is a meaningful cleanup for the base framework
  • generate_image() in sampling.py no longer takes the entire FluxTrainer.Config — passing explicit parameters is cleaner and makes the function's dependencies obvious
  • FluxValidator.flux_init() now takes dump_folder: str instead of the full FluxTrainer.Config — reduces coupling
  • The __post_init__ enforcement of prompt_dropout_prob=0 for validation is a good safety net (though should be an error per point 5)
  • README typo fix (.run_train.sh → ./run_train.sh)

Contributor

@tianyu-l tianyu-l left a comment


LGTM. I didn't see or check numerics results, and assumed they are bitwise identical to before.

It's probably worth having a loss test, because this diffusion model uses many components that differ from the LLMs.

)

# Build tokenizers from the config
tokenizer_container = config.tokenizer.build()
Contributor

nit: no need to call it container -- you actually wouldn't know from this file

Comment on lines +62 to +63
"clip_tokens": self.clip_tokenizer.encode(text),
"t5_tokens": self.t5_tokenizer.encode(text),
Contributor

Do you think we can just call them

Suggested change
"clip_tokens": self.clip_tokenizer.encode(text),
"t5_tokens": self.t5_tokenizer.encode(text),
"clip": self.clip_tokenizer.encode(text),
"t5": self.t5_tokenizer.encode(text),

Comment on lines +77 to +80
if "t5_tokens" in tokens:
    result["t5_text"] = self.t5_tokenizer.decode(tokens["t5_tokens"])
if "clip_tokens" in tokens:
    result["clip_text"] = self.clip_tokenizer.decode(tokens["clip_tokens"])
Contributor

Suggested change
if "t5_tokens" in tokens:
    result["t5_text"] = self.t5_tokenizer.decode(tokens["t5_tokens"])
if "clip_tokens" in tokens:
    result["clip_text"] = self.clip_tokenizer.decode(tokens["clip_tokens"])
if "t5" in tokens:
    result["t5"] = self.t5_tokenizer.decode(tokens["t5"])
if "clip" in tokens:
    result["clip"] = self.clip_tokenizer.decode(tokens["clip"])

Returns:
A dict with keys "clip_text" and/or "t5_text".
"""
result = {}
Contributor

nit

Suggested change
result = {}
results = {}


Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.
