Conversation
```sh
overrides=""
if [ $# -ne 0 ]; then
    overrides="$*"
fi
```
This can be improved; see `run_train.sh`.
```python
pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u)

# repeat along batch dimension to update both unconditional and conditional latents
pred = pred.repeat(2, 1, 1)
```
Wait, what happens before and after this for non-batched sampling?
Did we do the repeat (unconditional + conditional) before?
The repeat for `pred` was previously handled automatically by broadcasting in `latents = latents + (t_prev - t_curr) * pred`.
Previously, `latents` would always have batch size 1 or 2, depending on whether CFG was used, and `pred` always had batch size 1. Since the new implementation accepts batch sizes greater than 1, `latents` and `pred` won't be broadcast automatically in the CFG case: `latents` has twice the batch size of `pred` after `pred_u` and `pred_c` are combined, and broadcasting only works when `pred`'s batch size is 1.
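The broadcasting issue described above can be illustrated with a minimal sketch of the CFG update step. The function name and 3-D tensor shapes here are assumptions for illustration, not the actual torchtitan code:

```python
import torch

def cfg_step(latents, pred_u, pred_c, guidance_scale, t_curr, t_prev):
    # Hypothetical sketch of the CFG update described above.
    # latents: [2B, ...] (unconditional and conditional halves stacked);
    # pred_u / pred_c: [B, ...] each.
    pred = pred_u + guidance_scale * (pred_c - pred_u)
    # With B > 1, broadcasting no longer expands pred to match latents,
    # so repeat explicitly along the batch dimension.
    pred = pred.repeat(2, 1, 1)
    return latents + (t_prev - t_curr) * pred
```

With B = 1 the final line would broadcast `pred` against `latents` on its own; the explicit `repeat` is what makes the same code work for larger batches.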
```python
    return self.tiktokenizer.vocab_size

def encode(self, text: str | list[str]) -> torch.Tensor:
```
It seems you are batching within the tokenizer. Can't we rely on the dataloader to do this for us, as happens in training?
Here we are not splitting the data into batches, but rather accepting a batch (a list of strings) and encoding the entire thing to return as a tensor. I did this because it matches the behavior of the real `FluxTokenizer`, and I think it's also more convenient than encoding every string in a batch and then converting the returned lists of ids into tensor format in the program that calls `encode`.
Great, sounds good to me.
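The batched encode behavior described in this thread can be sketched roughly as follows. The `tokenize` callable, padding scheme, and parameter names are assumptions for illustration, not the real `FluxTokenizer` API:

```python
import torch

def encode(text, tokenize, pad_id=0, max_len=16):
    # Hypothetical sketch: accept a single string or a list of strings
    # and return one stacked [batch, max_len] id tensor, so the caller
    # never has to convert lists of ids into tensors itself.
    if isinstance(text, str):
        text = [text]
    rows = []
    for t in text:
        ids = tokenize(t)[:max_len]
        # Pad every row to the same length so the rows stack into a tensor.
        rows.append(ids + [pad_id] * (max_len - len(ids)))
    return torch.tensor(rows)
```

A single string comes back with batch size 1, and a list of N strings comes back as an [N, max_len] tensor, matching the "accept a batch, return a tensor" behavior described above.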
```python
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
from torchtitan.tools.logging import init_logger, logger


def inference_call(
```
I feel this is too minimal to be a standalone function.
Also, can we do generate + save for each batch? I'm not sure why it helps to generate everything and then save everything.
```python
    bs=config.inference.batch_size,
)

if config.inference.save_img_folder:
    output_dir=os.path.join(
        config.job.dump_folder,
        config.inference.save_img_folder,
    ),
```
This only needs to be done once, outside the for-loop.
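The suggestion above can be sketched as a small helper that builds (and creates) the output directory once before the per-batch loop. The function name is hypothetical; the config field names are taken from the snippet above:

```python
import os

def prepare_output_dir(dump_folder: str, save_img_folder: str) -> str:
    # Build and create the image output directory once, before the
    # per-batch loop, rather than re-joining the path for every batch.
    output_dir = os.path.join(dump_folder, save_img_folder)
    os.makedirs(output_dir, exist_ok=True)
    return output_dir
```

Each batch iteration then only writes files into the precomputed `output_dir`.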
This PR adds the `FluxStateDictAdapter`, allowing us to convert checkpoints to and from HF. Additional changes:

- Modifies the `download_hf_assets` script to support downloading diffusion-type safetensor files
- Registers Flux's `TrainSpec` in `convert_from_hf` and `convert_to_hf` so that the conversion scripts can be reused, e.g. `python ./scripts/checkpoint_conversion/convert_from_hf.py ./assets/hf/FLUX.1-dev/transformer ./outputs/temp --model_name flux --model_flavor flux-dev`

Tests: Performing a KL divergence test on the forward pass of converted weights loaded in `torchtitan` against HF weights loaded with HF's `FluxTransformer2DModel`, we get:

```
Average loss for test from_hf is 7.233546986222528e-13
```

Additionally, we can now run inference with HF weights to verify the changes made in #1548.

### Batched Inference on TorchTitan:

| | prompt0 | prompt1 | prompt2 |
| --- | --- | --- | --- |
| no CFG | <img width="1024" height="1024" alt="prompt0_nocfg" src="https://github.com/user-attachments/assets/421fab49-239a-4ca2-b51a-16823d89acfd" /> | <img width="1024" height="1024" alt="prompt1_nocfg" src="https://github.com/user-attachments/assets/534b557e-7b93-4f2e-b3b3-3a0c7cf57c40" /> | <img width="1024" height="1024" alt="prompt2_nocfg" src="https://github.com/user-attachments/assets/d0f33526-f95d-47db-b5a6-6200bfa151f9" /> |
| CFG | <img width="1024" height="1024" alt="prompt0_cfg" src="https://github.com/user-attachments/assets/83234675-eb47-4785-abe1-0f07dd854f1c" /> | <img width="1024" height="1024" alt="prompt1_cfg" src="https://github.com/user-attachments/assets/5e76f3e7-0ca3-47a4-a0ef-3c7e983e8c2c" /> | <img width="1024" height="1024" alt="prompt2_cfg" src="https://github.com/user-attachments/assets/c8cbe367-d96e-4559-a201-48e8dc3d18ee" /> |
This PR reuses the trainer's parallelization to perform batched inference on Flux. It follows up on and addresses comments from @CarlosGomes98's pytorch#1227, and irons out some optimizations for better code clarity. In this implementation, images are saved locally from each rank while the global information, such as the prompt and index of each image, is maintained. It also removes unnecessary padding and handles prompt sets that are not divisible by the batch size.
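One way to shard a prompt set across ranks without padding, while keeping each prompt's global index, is a strided slice per rank. This is a hedged sketch of the idea described above, not the PR's actual implementation; the function name and signature are assumptions:

```python
def shard_prompts(prompts, rank, world_size):
    # Hypothetical sketch: give each rank a strided slice of the prompt
    # list, keeping each prompt's global index so saved images can be
    # named consistently across ranks. Because the slice is strided,
    # prompt sets that don't divide evenly by world_size need no padding;
    # some ranks simply get one fewer prompt.
    return [(i, p) for i, p in enumerate(prompts) if i % world_size == rank]
```

For 5 prompts on 2 ranks, rank 0 gets global indices 0, 2, 4 and rank 1 gets 1, 3, with no dummy entries added.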