[HF] Model Definition Conversion Support for FLUX#1582
Conversation
tianyu-l
left a comment
Picture looks great!
Had some comments.
```python
except FileNotFoundError:
    index_files = [
        "model.safetensors.index.json",
        "diffusion_pytorch_model.safetensors.index.json",
```
I think we shouldn't hardcode a diffusion-model-specific file name here. If this is special treatment, let's do it in Flux's own state_dict_adapter (maybe without inheriting this __init__).
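A minimal sketch of what this suggestion could look like, keeping the diffusion-specific index filename out of the generic adapter by overriding it in Flux's own adapter. Class and attribute names here are illustrative, not the actual torchtitan API:

```python
import os

class BaseStateDictAdapter:
    # generic adapters only know about the standard HF index file
    HF_INDEX_FILE = "model.safetensors.index.json"

    def __init__(self, hf_assets_path):
        self.index_path = os.path.join(hf_assets_path, self.HF_INDEX_FILE)

class FluxStateDictAdapter(BaseStateDictAdapter):
    # Flux checkpoints exported by diffusers use a different index filename,
    # so the special case lives in the Flux subclass, not the base class.
    HF_INDEX_FILE = "diffusion_pytorch_model.safetensors.index.json"
```

This keeps the base `__init__` free of model-specific branching while still letting Flux resolve its own index file.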
```python
import torch
import torch.distributed.checkpoint as dcp
import torchtitan.experiments.flux  # noqa: F401
```
Why do we need this import here, when we didn't need to import llama?
```python
return state_dict
```

```python
def build_flux_state_dict_adapter(
```
Do we need this function? If we wanted the `TrainSpec` field to be a `build_xxxx` callable rather than plugging in the class directly, this function would make sense, but I don't see it used anywhere right now.
```python
self.timestep_cycle = itertools.cycle(val_timesteps)
```

```python
# Disable classifier free guidance for validation
self.job_config.training.classifier_free_guidance_prob = 0.0
```
I see the only appearance of this field at
You already called super().__init__() above; will this still take effect?
It also sounds to me like you might have to pass this in the constructor, unless you have other better ideas?
You're right: in the above implementation this line unintentionally disabled dropout during training as well, since the job_config is changed globally.
I switched this override to happen inside the Flux validator's validate and switch it back at the end.
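The fix described above can be sketched as a save/override/restore pattern inside `validate`, so the global job_config is only mutated for the duration of validation. Class and field names below are illustrative stand-ins, not the actual torchtitan API:

```python
from types import SimpleNamespace

class FluxValidator:
    def __init__(self, job_config):
        self.job_config = job_config

    def validate(self):
        cfg = self.job_config.training
        saved = cfg.classifier_free_guidance_prob
        # Disable CFG dropout only while validating.
        cfg.classifier_free_guidance_prob = 0.0
        try:
            pass  # ... run the validation loop here ...
        finally:
            # Restore the training-time value even if validation raises.
            cfg.classifier_free_guidance_prob = saved

job_config = SimpleNamespace(
    training=SimpleNamespace(classifier_free_guidance_prob=0.1)
)
FluxValidator(job_config).validate()
```

The `try`/`finally` guarantees the training-time dropout probability survives validation, which is exactly what the constructor-time override broke.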
```python
import torch
```

```python
logger = logging.getLogger()
```
Nit: it's better to put this line after the last import (line 21 in this case).
Removed the redundant from_hf_map_combine mapping in favor of reversed_combination_plan. The reasoning: from_hf_map_combine forced combining in only one direction (tt->hf or hf->tt). Now, if a conversion needs both combining and splitting, we can check whether a key is in combination_plan to know we need to split, or check whether it is in reversed_combination_plan to know we need to combine.
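The direction-agnostic lookup described above can be sketched as follows. The weight keys are illustrative, not the actual Flux conversion map entries:

```python
# One plan drives both directions: a single torchtitan key that corresponds
# to several HF keys must be split one way and combined the other way.
combination_plan = {
    "attn.qkv.weight": (
        "attn.to_q.weight",
        "attn.to_k.weight",
        "attn.to_v.weight",
    ),
}
# Reversing the plan gives the combine-direction lookup for free.
reversed_combination_plan = {v: k for k, v in combination_plan.items()}

def needs_split(key):
    # This key expands into several keys in the target format.
    return key in combination_plan

def needs_combine(keys):
    # These keys collapse into a single key in the target format.
    return tuple(keys) in reversed_combination_plan
```

A single source of truth avoids the two maps drifting out of sync, which is the failure mode the redundant from_hf_map_combine invited.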
tianyu-l
left a comment
I didn't read into the conversion map details, but the organization looks good to me.
This PR adds the `FluxStateDictAdapter`, allowing us to convert checkpoints to and from HF.

Additional changes:
- Modifies the `download_hf_assets` script to support downloading diffusion-type safetensor files
- Registers Flux's `TrainSpec` in `convert_from_hf` and `convert_to_hf` so that the conversion script can be reused, e.g. `python ./scripts/checkpoint_conversion/convert_from_hf.py ./assets/hf/FLUX.1-dev/transformer ./outputs/temp --model_name flux --model_flavor flux-dev`

Tests: Performing a KL divergence test on the forward pass of converted weights loaded in `torchtitan` against HF weights loaded with HF's `FluxTransformer2DModel`, we get:
```
Average loss for test from_hf is 7.233546986222528e-13
```
Additionally, we can now run inference with HF weights to verify the changes made in pytorch#1548

### Batched Inference on TorchTitan:
| | prompt0 | prompt1 | prompt2 |
| --- | --- | --- | --- |
| no CFG | <img width="1024" height="1024" alt="prompt0_nocfg" src="https://github.com/user-attachments/assets/421fab49-239a-4ca2-b51a-16823d89acfd" /> | <img width="1024" height="1024" alt="prompt1_nocfg" src="https://github.com/user-attachments/assets/534b557e-7b93-4f2e-b3b3-3a0c7cf57c40" /> | <img width="1024" height="1024" alt="prompt2_nocfg" src="https://github.com/user-attachments/assets/d0f33526-f95d-47db-b5a6-6200bfa151f9" /> |
| CFG | <img width="1024" height="1024" alt="prompt0_cfg" src="https://github.com/user-attachments/assets/83234675-eb47-4785-abe1-0f07dd854f1c" /> | <img width="1024" height="1024" alt="prompt1_cfg" src="https://github.com/user-attachments/assets/5e76f3e7-0ca3-47a4-a0ef-3c7e983e8c2c" /> | <img width="1024" height="1024" alt="prompt2_cfg" src="https://github.com/user-attachments/assets/c8cbe367-d96e-4559-a201-48e8dc3d18ee" /> |
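The KL-divergence equivalence test mentioned above can be sketched as follows. The stand-in distributions below are illustrative; the real test compares forward-pass outputs from torchtitan (converted weights) against HF's `FluxTransformer2DModel` on identical inputs:

```python
import math

def kl_divergence(p, q):
    # KL(p || q) for two discrete probability distributions.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Stand-ins for the two models' output distributions on the same input.
ref_out = [0.2, 0.5, 0.3]        # HF reference forward pass
converted_out = [0.2, 0.5, 0.3]  # torchtitan forward pass with converted weights

loss = kl_divergence(ref_out, converted_out)
# A near-zero value (~1e-13 in the PR's test) indicates the weight
# conversion is numerically faithful.
```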