[Flux] Refactor flux config: Add FluxTokenizerContainer, remove Validation class #2533
Conversation
@claude Review PR
tianyu-l left a comment
maybe show evidence that numerics didn't change?
torchtitan/models/flux/configs.py
Outdated
"""How many denoising steps to sample when generating an image"""
eval_freq: int = 100
"""Frequency of evaluation/sampling during training"""
random_init: bool = False
Suggested change:
- random_init: bool = False
+ _random_init: bool = False
since it's for test only
torchtitan/models/flux/tokenizer.py
Outdated
def encode(self, *args, **kwargs) -> list[int]:
    raise NotImplementedError(
        "Use t5_tokenizer.encode() or clip_tokenizer.encode() directly"
    )

def decode(self, *args, **kwargs) -> str:
    raise NotImplementedError(
        "Use t5_tokenizer.decode() or clip_tokenizer.decode() directly"
    )
Let's implement encode/decode as wrapper functions over t5 and clip encode/decode to return a tuple / dict of results, so that the dataloader / model never needs to call t5 / clip tokenizers directly.
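A minimal sketch of the suggested wrapper, under assumptions: the class and field names (`FluxTokenizer`, `t5_tokenizer`, `clip_tokenizer`) follow the strings in the quoted snippet, the `"t5"`/`"clip"` dict keys are illustrative, and `_FakeTok` is a stand-in for the real tokenizers:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FluxTokenizer:
    """Wraps both backend tokenizers behind a single encode/decode API."""

    t5_tokenizer: Any
    clip_tokenizer: Any

    def encode(self, text: str) -> dict[str, list[int]]:
        # Tokenize once per backend and bundle the results, so the
        # dataloader/model never calls the t5/clip tokenizers directly.
        return {
            "t5": self.t5_tokenizer.encode(text),
            "clip": self.clip_tokenizer.encode(text),
        }

    def decode(self, tokens: dict[str, list[int]]) -> dict[str, str]:
        # Decode whichever token streams are present in the input dict.
        backends = {"t5": self.t5_tokenizer, "clip": self.clip_tokenizer}
        return {k: backends[k].decode(v) for k, v in tokens.items() if k in backends}


class _FakeTok:
    """Stand-in tokenizer for this sketch; real code uses the T5/CLIP tokenizers."""

    def __init__(self, offset: int):
        self.offset = offset

    def encode(self, text: str) -> list[int]:
        return [self.offset + ord(c) for c in text]

    def decode(self, ids: list[int]) -> str:
        return "".join(chr(i - self.offset) for i in ids)


tok = FluxTokenizer(t5_tokenizer=_FakeTok(0), clip_tokenizer=_FakeTok(1))
enc = tok.encode("hi")  # {"t5": [104, 105], "clip": [105, 106]}
dec = tok.decode(enc)   # {"t5": "hi", "clip": "hi"}
```

Returning a dict (rather than a tuple) keeps `decode` symmetric: it can accept a partial dict with only one stream and decode just that one.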
"""Similar to above, this is a hack to get the test tokenizer asset paths."""
def __post_init__(self):
    if self.generate_timesteps and self.prompt_dropout_prob != 0.0:
        logger.warning(
let's error out instead of printing a warning
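A hedged sketch of the suggested change, with the validation moved from a warning to a hard error. The container name `SamplingArgs` and the defaults are placeholders; the field names come from the quoted snippet:

```python
from dataclasses import dataclass


@dataclass
class SamplingArgs:  # placeholder name for illustration only
    generate_timesteps: bool = False
    prompt_dropout_prob: float = 0.0

    def __post_init__(self):
        # Fail fast on an invalid combination instead of logging a warning,
        # so misconfigured runs stop at config-parse time.
        if self.generate_timesteps and self.prompt_dropout_prob != 0.0:
            raise ValueError(
                "generate_timesteps requires prompt_dropout_prob == 0.0, "
                f"got {self.prompt_dropout_prob}"
            )
```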
torchtitan/models/flux/configs.py
Outdated
"""Batch size for inference"""
img_size: int = 256
"""Image size for inference"""
enable_classifier_free_guidance: bool = False
so we are putting these configs in both validator.config and inference.config?
Moved to a shared config as Claude suggested; these duplicated fields are used for generating / sampling an image in both validation and inference.
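One way the sharing could look, as a sketch: a single dataclass holds the sampling fields, and both jobs embed it. The field names follow the quoted snippet; the container names (`SamplingConfig`, `Validation`, `Inference`) and the `batch_size` default are assumptions, not the PR's actual classes:

```python
from dataclasses import dataclass, field


@dataclass
class SamplingConfig:
    """Fields needed to generate/sample an image, shared by validation and inference."""

    batch_size: int = 1  # assumed default
    img_size: int = 256
    enable_classifier_free_guidance: bool = False


@dataclass
class Validation:
    # default_factory gives each job its own SamplingConfig instance
    sampling: SamplingConfig = field(default_factory=SamplingConfig)


@dataclass
class Inference:
    sampling: SamplingConfig = field(default_factory=SamplingConfig)
```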
"""Classifier-free guidance with probability `p` to dropout each text encoding independently.
If `n` text encoders are used, the unconditional model is trained in `p ^ n` of all steps.
For example, if `n = 2` and `p = 0.447`, the unconditional model is trained in 20% of all steps"""
prompt_dropout_prob: float = 0.0
How legit is this name change -- what does FLUX repo use?
The Flux repo doesn't have a dataloader function; they only open-sourced the inference / sampling related code, so we don't have a direct corresponding name for this field.
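The `p ^ n` claim in the quoted docstring can be checked directly:

```python
# With n text encoders each dropped out independently with probability p,
# all encodings are dropped (a fully unconditional step) in p**n of steps.
p, n = 0.447, 2
unconditional_fraction = p ** n
print(round(unconditional_fraction, 2))  # ~0.2, matching the docstring's 20%
```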
@claude review
Claude finished @wwwjn's task in 2m 34s. PR Review: [Flux] Refactor flux config

Summary: This PR separates tokenizer config from encoder config, introduces …

Issues: 1. Bug: …
)

# Build tokenizers from the config
tokenizer_container = config.tokenizer.build()
nit: no need to call it container -- you actually wouldn't know from this file
"clip_tokens": self.clip_tokenizer.encode(text),
"t5_tokens": self.t5_tokenizer.encode(text),
Do you think we can just call them `clip` and `t5`?
Suggested change:
- "clip_tokens": self.clip_tokenizer.encode(text),
- "t5_tokens": self.t5_tokenizer.encode(text),
+ "clip": self.clip_tokenizer.encode(text),
+ "t5": self.t5_tokenizer.encode(text),
if "t5_tokens" in tokens:
    result["t5_text"] = self.t5_tokenizer.decode(tokens["t5_tokens"])
if "clip_tokens" in tokens:
    result["clip_text"] = self.clip_tokenizer.decode(tokens["clip_tokens"])
Suggested change:
- if "t5_tokens" in tokens:
-     result["t5_text"] = self.t5_tokenizer.decode(tokens["t5_tokens"])
- if "clip_tokens" in tokens:
-     result["clip_text"] = self.clip_tokenizer.decode(tokens["clip_tokens"])
+ if "t5" in tokens:
+     result["t5"] = self.t5_tokenizer.decode(tokens["t5"])
+ if "clip" in tokens:
+     result["clip"] = self.clip_tokenizer.decode(tokens["clip"])
Returns:
    A dict with keys "clip_text" and/or "t5_text".
"""
result = {}
nit
Suggested change:
- result = {}
+ results = {}
classifier_free_guidance_prob → prompt_dropout_prob: avoids confusion with the inference-time CFG settings (enable_classifier_free_guidance, classifier_free_guidance_scale)