
[HF] Model Definition Conversion Support for FLUX#1582

Merged
wesleytruong merged 4 commits into main from flux_hf_conversion
Aug 20, 2025

Conversation

@wesleytruong (Contributor) commented Aug 15, 2025

This PR adds the FluxStateDictAdapter, allowing us to convert checkpoints to and from HF.

Additional changes:

  • Modifies download_hf_assets script to support downloading diffusion-type safetensor files
  • Registers Flux's TrainSpec in convert_from_hf and convert_to_hf so that the conversion scripts can be reused
    • e.g. python ./scripts/checkpoint_conversion/convert_from_hf.py ./assets/hf/FLUX.1-dev/transformer ./outputs/temp --model_name flux --model_flavor flux-dev

Tests:
Performing a KL divergence test on the forward pass of the converted weights loaded in torchtitan against the HF weights loaded with HF's FluxTransformer2DModel, we get:

Average loss for test from_hf is 7.233546986222528e-13
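As an illustration of the metric (not the actual test harness, which compares forward-pass outputs of the two model loads), KL divergence between two discrete distributions can be sketched in plain Python:

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two discrete probability distributions given as lists.

    A near-zero value, like the ~7e-13 reported above, indicates the two
    model outputs are numerically identical up to floating-point noise.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)
```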

Additionally, we can now run inference with HF weights to verify the changes made in #1548

Batched Inference on TorchTitan:

|        | prompt0               | prompt1               | prompt2               |
| ------ | --------------------- | --------------------- | --------------------- |
| no CFG | prompt0_nocfg (image) | prompt1_nocfg (image) | prompt2_nocfg (image) |
| CFG    | prompt0_cfg (image)   | prompt1_cfg (image)   | prompt2_cfg (image)   |

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Aug 15, 2025
@tianyu-l (Contributor) left a comment:

Picture looks great!

Had some comments.

```python
except FileNotFoundError:
    index_files = [
        "model.safetensors.index.json",
        "diffusion_pytorch_model.safetensors.index.json",
```

I think we shouldn't hardcode the diffusion-model-specific file name here. If this needs special treatment, let's do it in Flux's own state_dict_adapter (maybe without inheriting this `__init__`).
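The suggested refactor could look like the following sketch, where the subclass carries the diffusion-specific index filename instead of the shared base class. Class names mirror the PR, but the attributes and structure are illustrative, not torchtitan's actual API:

```python
import json
import os

class StateDictAdapter:
    """Hypothetical base adapter that reads the standard HF index file."""
    INDEX_FILE = "model.safetensors.index.json"

    def __init__(self, hf_assets_path):
        # weight_map tells us which shard each parameter lives in.
        with open(os.path.join(hf_assets_path, self.INDEX_FILE)) as f:
            self.weight_map = json.load(f)["weight_map"]

class FluxStateDictAdapter(StateDictAdapter):
    """Diffusion checkpoints ship a differently named index file, so the
    Flux-specific adapter carries that knowledge rather than the base class."""
    INDEX_FILE = "diffusion_pytorch_model.safetensors.index.json"
```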


```python
import torch
import torch.distributed.checkpoint as dcp
import torchtitan.experiments.flux  # noqa: F401
```

Why do we need this, but didn't need to import llama?

```python
return state_dict


def build_flux_state_dict_adapter(
```

Do we need this function? If we want to make the field in TrainSpec something like `build_xxxx`, and not plug in the class directly, we could use this function, but I didn't see it used anywhere now.

```python
self.timestep_cycle = itertools.cycle(val_timesteps)

# Disable classifier free guidance for validation
self.job_config.training.classifier_free_guidance_prob = 0.0
```

I see the only appearance of this field at

```python
dropout_prob = self.job_config.training.classifier_free_guidance_prob
```

You already called `super().__init__()` above; will this still take effect?

It also sounds to me like you might have to pass this in the constructor, unless you have better ideas?

@wesleytruong (Contributor, Author) replied:

You're right, in the above implementation this line unintentionally disabled dropout during training as well, since the job_config is changed globally.

I switched this override to happen inside the Flux validator's `validate` method now, and switch back at the end.
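The override-and-restore pattern the author describes can be sketched as follows; the config field name matches the PR, but the class shape and dict-based config are illustrative, not torchtitan's real validator:

```python
# Sketch of the fix: override the CFG dropout probability only for the
# duration of validate(), then restore it so training keeps its configured
# value even if validation raises. Class shape is illustrative.
class FluxValidator:
    def __init__(self, job_config):
        self.job_config = job_config

    def validate(self):
        saved = self.job_config["classifier_free_guidance_prob"]
        self.job_config["classifier_free_guidance_prob"] = 0.0
        try:
            # Validation runs with CFG dropout disabled.
            return self._run_validation()
        finally:
            # Restored even if validation raises.
            self.job_config["classifier_free_guidance_prob"] = saved

    def _run_validation(self):
        # Stand-in for the real validation loop; returns the active prob.
        return self.job_config["classifier_free_guidance_prob"]
```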


```python
import torch

logger = logging.getLogger()
```

Nit: it's better to put this line after the last import (line 21 in this case).

Removed the redundant `from_hf_map_combine` mapping in favor of `reversed_combination_plan`.

The reasoning is that `from_hf_map_combine` forced combining in only one direction (tt -> hf or hf -> tt). Now, if a conversion needs both combining and splitting, we can check for the key in `combination_plan` to know we need to split, or check for the key in `reversed_combination_plan` to know we need to combine.
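The single-table approach can be sketched like this: one `combination_plan` maps a fused HF key to the torchtitan keys it splits into, and reversing it drives the opposite direction. The key names below are illustrative, not FLUX's real parameter names:

```python
# combination_plan: fused HF key -> the torchtitan keys it splits into.
# Reversing it lets one table serve both conversion directions.
combination_plan = {
    "qkv.weight": ("q.weight", "k.weight", "v.weight"),
}
reversed_combination_plan = {
    parts: fused for fused, parts in combination_plan.items()
}

def needs_split(hf_key):
    # hf -> tt: a fused key must be split into its parts.
    return hf_key in combination_plan

def needs_combine(tt_keys):
    # tt -> hf: a group of part keys must be fused back together.
    return tuple(tt_keys) in reversed_combination_plan
```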
@tianyu-l (Contributor) left a comment:
I didn't read into the conversion map details, but the organization looks good to me.

@wesleytruong wesleytruong merged commit c0b2e5a into main Aug 20, 2025
9 checks passed
@tianyu-l tianyu-l deleted the flux_hf_conversion branch August 20, 2025 20:36
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026

Labels: CLA Signed (managed by the Meta Open Source bot)

4 participants