Flux Batched Inference #1548

Merged
@wesleytruong merged 3 commits into main from flux_inference on Aug 12, 2025
Conversation

@wesleytruong (Contributor)

This PR reuses the trainer's parallelization to perform batched inference on Flux. It follows up on and addresses comments from @CarlosGomes98 in #1227, and irons out some optimizations for better code clarity.

This implementation saves images locally from each rank while maintaining global information such as the prompt and index of each image. It also removes unnecessary padding and handles prompt sets that are not divisible by the batch size.

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Aug 9, 2025
Comment on lines 17 to 20
overrides=""
if [ $# -ne 0 ]; then
overrides="$*"
fi
Contributor:

this can be improved, see run_train.sh

pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u)

# repeat along batch dimension to update both unconditional and conditional latents
pred = pred.repeat(2, 1, 1)
Contributor:

wait what happens before and after for non-batched sampling?
Did we do repeat (unconditional + conditional) before?

Contributor (Author):

The repeat for `pred` was previously handled automatically by broadcasting in `latents = latents + (t_prev - t_curr) * pred`.

Previously, `latents` always had batch size 1 or 2 (depending on whether CFG was used), and `pred` always had batch size 1. Since the new implementation accepts batch sizes greater than 1, `latents` and `pred` won't broadcast automatically in the CFG case: after `pred_u` and `pred_c` are combined, `latents` has twice the batch size of `pred`, and broadcasting only works when `pred`'s batch size is 1.
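The batch-size mismatch can be sketched with plain Python lists standing in for tensors (a toy illustration, not the actual Flux sampling code; `cfg_step` and its arguments are hypothetical names):

```python
def cfg_step(latents, pred_u, pred_c, scale, dt):
    """One CFG update step. `latents` has batch size 2B (unconditional rows
    followed by conditional rows); `pred_u` and `pred_c` each have batch size B."""
    # classifier-free guidance combination: result has batch size B
    pred = [u + scale * (c - u) for u, c in zip(pred_u, pred_c)]
    # repeat along the batch dimension (analogous to pred.repeat(2, 1, 1)) so
    # the elementwise update lines up with latents' 2B rows; with B == 1 this
    # was previously handled implicitly by broadcasting
    pred = pred + pred
    return [l + dt * p for l, p in zip(latents, pred)]

# B = 2: four latent rows, two unconditional and two conditional predictions
print(cfg_step([0.0, 0.0, 0.0, 0.0], [1.0, 1.0], [3.0, 3.0], 0.5, 1.0))
# -> [2.0, 2.0, 2.0, 2.0]
```

Without the repeat, the final `zip` would silently drop the conditional half of `latents`, which is the bug the explicit `pred.repeat(2, 1, 1)` avoids.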

return self.tiktokenizer.vocab_size

def encode(self, text: str) -> torch.Tensor:
def encode(self, text: str | list[str]) -> torch.Tensor:
Contributor:

It seems you are doing batchify within tokenizer. Can't we rely on dataloader to do this for us, as what happens in training?

Contributor (Author):

Here we are not splitting the data into batches, but rather accepting a batch (a list of strings) and encoding the whole thing to return as a tensor. I did this because it matches the behavior of the real FluxTokenizer, and it's also more convenient than encoding every string in a batch and then converting the returned lists of ids into tensor format in the program that calls encode.
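The `str | list[str]` behavior described above can be sketched in miniature (a toy stand-in with a hypothetical whitespace vocabulary, not the real FluxTokenizer, which returns token-id tensors):

```python
def encode(text, vocab):
    """Accept a single string or a batch (list of strings) and return a
    rectangular batch of padded id rows, mirroring a [batch, seq_len] tensor."""
    texts = [text] if isinstance(text, str) else text
    ids = [[vocab.get(tok, 0) for tok in t.split()] for t in texts]
    max_len = max(len(row) for row in ids)
    # pad every row to the longest sequence so the result is rectangular
    return [row + [0] * (max_len - len(row)) for row in ids]

vocab = {"a": 1, "cat": 2, "sat": 3}
print(encode("a cat", vocab))                   # -> [[1, 2]]
print(encode(["a cat", "a cat sat"], vocab))    # -> [[1, 2, 0], [1, 2, 3]]
```

The caller gets one rectangular result either way, which is what makes returning a tensor directly convenient.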

Contributor:

great, sounds good to me.

@@ -0,0 +1,110 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Contributor:

NVIDIA??? lol

from torchtitan.tools.logging import init_logger, logger


def inference_call(
Contributor:

I feel this is too minimal to be a standalone function.

Also, could we generate + save for each batch? I'm not sure why it helps to generate everything first and then save everything.

bs=config.inference.batch_size,
)

if config.inference.save_img_folder:
Contributor:

why if?

Comment on lines 94 to 97
output_dir=os.path.join(
config.job.dump_folder,
config.inference.save_img_folder,
),
Contributor:

only need to do this once outside the for-loop

@tianyu-l (Contributor) left a comment:

Looks solid!

@wesleytruong merged commit d14f1e3 into main on Aug 12, 2025
6 checks passed
@tianyu-l deleted the flux_inference branch on August 12, 2025 05:50
wesleytruong added a commit that referenced this pull request Aug 20, 2025
This PR adds the `FluxStateDictAdapter`, allowing us to convert
checkpoints to and from HF.

Additional changes:
- Modifies `download_hf_assets` script to support downloading
diffusion-type safetensor files
- Registers Flux's `TrainSpec` in `convert_from_hf` and `convert_to_hf`
so that conversion script can be reused
- e.g. `python ./scripts/checkpoint_conversion/convert_from_hf.py
./assets/hf/FLUX.1-dev/transformer ./outputs/temp --model_name flux
--model_flavor flux-dev`

Tests:
Performing KL divergence test on the forward pass of converted weights
loaded in `torchtitan` and HF weights loaded with HF
`FluxTransformer2DModel`, we get:
```
Average loss for test from_hf is 7.233546986222528e-13
```
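The KL-divergence check described above can be sketched generically: compare per-row softmax distributions from two forward passes and average (a stdlib-only stand-in; the actual test compares torchtitan outputs against HF `FluxTransformer2DModel` outputs, and the function and variable names here are illustrative):

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def avg_kl(logits_a, logits_b):
    """Mean KL(P_a || P_b) over rows of two logit batches of equal shape."""
    total = 0.0
    for ra, rb in zip(logits_a, logits_b):
        pa, pb = softmax(ra), softmax(rb)
        total += sum(p * math.log(p / q) for p, q in zip(pa, pb))
    return total / len(logits_a)

# identical forward-pass outputs give zero divergence
print(avg_kl([[0.1, 0.9]], [[0.1, 0.9]]))  # -> 0.0
```

A near-zero average, like the `7.23e-13` reported above, indicates the converted weights produce (numerically) the same distribution as the reference model.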

Additionally, we can now run inference with HF weights to verify the changes
made in #1548

### Batched Inference on TorchTitan:
| | prompt0 | prompt1 | prompt2 |
| --- | --- | --- | --- |
| no CFG | <img width="1024" height="1024" alt="prompt0_nocfg"
src="https://github.com/user-attachments/assets/421fab49-239a-4ca2-b51a-16823d89acfd"
/> | <img width="1024" height="1024" alt="prompt1_nocfg"
src="https://github.com/user-attachments/assets/534b557e-7b93-4f2e-b3b3-3a0c7cf57c40"
/> | <img width="1024" height="1024" alt="prompt2_nocfg"
src="https://github.com/user-attachments/assets/d0f33526-f95d-47db-b5a6-6200bfa151f9"
/> |
| CFG | <img width="1024" height="1024" alt="prompt0_cfg"
src="https://github.com/user-attachments/assets/83234675-eb47-4785-abe1-0f07dd854f1c"
/> | <img width="1024" height="1024" alt="prompt1_cfg"
src="https://github.com/user-attachments/assets/5e76f3e7-0ca3-47a4-a0ef-3c7e983e8c2c"
/> | <img width="1024" height="1024" alt="prompt2_cfg"
src="https://github.com/user-attachments/assets/c8cbe367-d96e-4559-a201-48e8dc3d18ee"
/> |
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026