
Add emulate in float8 and relative checks#1214

Merged
mori360 merged 22 commits into pytorch:main from mori360:add_fp8_emulate
May 28, 2025

Conversation

@mori360 (Contributor) commented May 21, 2025

Add [emulate](https://github.com/pytorch/ao/blob/554cb60c750e6ef31bbcafec74bb76a4578902da/torchao/float8/config.py#L193) in float8 to enable testing on older hardware.

Update the related warnings.

Test result:
Tested locally on an 8×H100 server.

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd`
[Screenshot 2025-05-21 at 2:38:39 PM]

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --float8.emulate`
[Screenshot 2025-05-21 at 2:39:01 PM]

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) May 21, 2025
@mori360 changed the title from "Add emulate in float 8 and relative checks" to "Add emulate in float8 and relative checks" May 21, 2025
@mori360 marked this pull request as ready for review May 22, 2025 03:09
@mori360 requested review from tianyu-l and vkuzo May 22, 2025 03:09
@tianyu-l (Contributor) left a comment

Thanks for working on this! I left some inline comments.

wandb
fsspec
tyro
torchao
Contributor:

I think the recommended way of installing torchao is still via nightly, similar to how we install PyTorch nightly for CI (https://github.com/pytorch/torchtitan/blob/main/.github/workflows/integration_test_8gpu.yaml#L39), but for torchao:

`USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git`

"To enable support on older hardware, set `float8.emulate` to True.",
)
return
elif float8_config.emulate and job_config.training.compile:
Contributor:

I wonder if emulate+compile works on H100? Since the original comment from @vkuzo is

torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can only be used in eager mode.

mori360 (Contributor, Author) commented May 22, 2025

Will run some tests on it.

Contributor Author:

Tested and it works; removed this exception.

Comment on lines 455 to 457
Whether to run on earlier hardware in CI test.
torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can
only be used in eager mode.
Contributor:

Suggested change
Whether to run on earlier hardware in CI test.
torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can
only be used in eager mode.
If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only, as the current CI does have sm_90 capability, required by Float8.
Not compatible with torch.compile.

This is assuming torch.compile + emulate doesn't work on >= H100 either. If not, we'll need to further adjust the code and helper message.

return
elif float8_config.emulate and job_config.training.compile:
logger.warning(
"Failed to run on emulate with compile on, please disable compile to allow on emulate.",
Contributor:

We should just raise an exception if the configuration combination is not runnable.

@mori360 marked this pull request as draft May 22, 2025 17:35
@mori360 marked this pull request as ready for review May 22, 2025 18:41
@mori360 requested review from fegin and tianyu-l May 22, 2025 18:41
Comment on lines 455 to 457
If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
as the current CI does have sm_90 capability, required by Float8.
Not compatible with torch.compile.
Contributor:

Suggested change
If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
as the current CI does have sm_90 capability, required by Float8.
Not compatible with torch.compile.
If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
as the current CI does not have sm_89 capability, required by Float8.

Contributor:

Could you make this update to the docstring? Otherwise it seems inaccurate.

Suggested change
If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
as the current CI does have sm_90 capability, required by Float8.
Not compatible with torch.compile.
If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
as the current CI does not have sm_89 capability, required by Float8.

logger.warning(
"Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
"Failed to swap to Float8Linear because float8 is only supported on SM89 or later."
"To enable support on older hardware, set `float8.emulate` to True.",
Contributor:

Suggested change
"To enable support on older hardware, set `float8.emulate` to True.",
"To enable testing on older hardware, set `float8.emulate` to True in eager mode.",


float8_config: Float8 = job_config.float8
if not has_cuda_capability(8, 9):
if not has_cuda_capability(8, 9) and not float8_config.emulate:
Contributor:

On sm < 89, we can't enable torch.compile with or without emulate, right? If so, let's do:

Suggested change
if not has_cuda_capability(8, 9) and not float8_config.emulate:
if not has_cuda_capability(8, 9) and (job_config.training.compile or not float8_config.emulate):

Also, it's a bit hard to read. A better way may be:

if has_cuda_capability(8, 9) or (float8_config.emulate and not job_config.training.compile):
    pass
else:
    raise ValueError(...)
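Putting the pieces of this thread together, the guard could look roughly like the sketch below. The config classes and `has_cuda_capability` here are simplified stand-ins for torchtitan's real objects, not the actual implementation:

```python
from dataclasses import dataclass


@dataclass
class Float8Config:
    emulate: bool = False


@dataclass
class TrainingConfig:
    compile: bool = False


def has_cuda_capability(major: int, minor: int) -> bool:
    # Stand-in: pretend we are on pre-sm89 hardware (e.g. the CPU CI runner).
    return False


def check_float8_support(float8: Float8Config, training: TrainingConfig) -> None:
    # Allowed: sm89+ hardware, or emulation in eager mode on older hardware.
    if has_cuda_capability(8, 9) or (float8.emulate and not training.compile):
        return
    raise ValueError(
        "Float8 requires SM89 or later; on older hardware set float8.emulate "
        "to True and disable torch.compile."
    )


# Emulation in eager mode passes; emulate + compile on old hardware raises.
check_float8_support(Float8Config(emulate=True), TrainingConfig(compile=False))
try:
    check_float8_support(Float8Config(emulate=True), TrainingConfig(compile=True))
except ValueError as e:
    print("rejected:", e)
```

This keeps the "happy path" condition readable and raises (rather than warns) on every unrunnable combination, which is what the review converged on.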

@mori360 marked this pull request as draft May 23, 2025 17:15
@tianyu-l (Contributor) left a comment

The CPU CI error is because we changed the warning to an exception when sm < 89.
I think we can just add the emulate flag to https://github.com/pytorch/torchtitan/blob/main/tests/unit_tests/test_model_converter.py#L42
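A minimal sketch of what that test fix might look like. The converter builder and config class below are hypothetical stand-ins for torchtitan's actual model-converter API, used only to illustrate why the flag is needed on the CPU runner:

```python
from dataclasses import dataclass


@dataclass
class Float8Config:
    emulate: bool = False


def build_float8_converter(config: Float8Config, has_sm89: bool) -> str:
    # Stand-in for the converter constructor: after this PR it raises on
    # pre-sm89 hardware unless emulation is enabled.
    if not has_sm89 and not config.emulate:
        raise ValueError("Failed to swap to Float8Linear: requires SM89 or later")
    return "float8_converter"


def test_build_model_converters_on_cpu():
    # The CPU CI runner has no sm89 GPU, so the test must set emulate=True
    # to get past the new hardware check.
    converter = build_float8_converter(Float8Config(emulate=True), has_sm89=False)
    assert converter == "float8_converter"


test_build_model_converters_on_cpu()
```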

@mori360 marked this pull request as ready for review May 28, 2025 03:08
@mori360 requested a review from tianyu-l May 28, 2025 03:09
@tianyu-l (Contributor) left a comment

LGTM. Please address final comments before merge.


pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu

pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cpu
Contributor:

Curious, can we specify `USE_CPP=0` here too?

Contributor Author:

Yeah, it works.

@mori360 merged commit 594a120 into pytorch:main May 28, 2025
6 checks passed
@mori360 deleted the add_fp8_emulate branch May 28, 2025 17:44
wwwjn pushed a commit to wwwjn/torchtitan that referenced this pull request Jun 2, 2025
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026
Labels

CLA Signed This label is managed by the Meta Open Source bot.

4 participants