Add emulate in float8 and related checks #1214
Conversation
tianyu-l
left a comment
Thanks for working on this! I left some inline comments.
.ci/docker/requirements.txt
Outdated
wandb
fsspec
tyro
torchao
I think the recommended way of installing torchao is still via nightly, similar to how we install pytorch nightly for CI
https://github.com/pytorch/torchtitan/blob/main/.github/workflows/integration_test_8gpu.yaml#L39
but for torchao:
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
| "To enable support on older hardware, set `float8.emulate` to True.", | ||
| ) | ||
| return | ||
| elif float8_config.emulate and job_config.training.compile: |
I wonder if emulate+compile works on H100? Since the original comment from @vkuzo is
torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can only be used in eager mode.
Will run some tests on it.
Tests look good; removed this exception.
torchtitan/config_manager.py
Outdated
Whether to run on earlier hardware in CI test.
torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can
only be used in eager mode.
Suggested change:
- Whether to run on earlier hardware in CI test.
- torch.compile with float8 dtypes is not going to work on older hardware, so the emulation can only be used in eager mode.
+ If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only, as the current CI does have sm_90 capability, required by Float8.
+ Not compatible with torch.compile.
This assumes torch.compile+emulate doesn't work on >= H100 either. If it does, we'll need to further adjust the code and helper message.
return
elif float8_config.emulate and job_config.training.compile:
logger.warning(
"Failed to run on emulate with compile on, please disable compile to allow on emulate.",
We should just raise an exception if the configuration combination is not runnable.
If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only,
as the current CI does have sm_90 capability, required by Float8.
Not compatible with torch.compile.
Suggested change:
- If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only, as the current CI does have sm_90 capability, required by Float8.
- Not compatible with torch.compile.
+ If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only, as the current CI does not have sm_89 capability, required by Float8.
could you make this update to the doc string? otherwise it seems inaccurate
logger.warning(
- "Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
+ "Failed to swap to Float8Linear because float8 is only supported on SM89 or later."
"To enable support on older hardware, set `float8.emulate` to True.",
| "To enable support on older hardware, set `float8.emulate` to True.", | |
| "To enable testing on older hardware, set `float8.emulate` to True in eager mode.", |
|
|
float8_config: Float8 = job_config.float8
- if not has_cuda_capability(8, 9):
+ if not has_cuda_capability(8, 9) and not float8_config.emulate:
On sm < 89, we can't enable torch.compile with/without emulate, right? If so let's do
Suggested change:
- if not has_cuda_capability(8, 9) and not float8_config.emulate:
+ if not has_cuda_capability(8, 9) and (job_config.training.compile or not float8_config.emulate):
Also it's a bit hard to read. A better way may be
if has_cuda_capability(8, 9) or (float8_config.emulate and not job_config.training.compile):
    pass
else:
    raise ValueError(...)
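A minimal runnable sketch of the suggested validation, with `has_cuda_capability` and the config fields stubbed out as plain arguments (names follow the PR discussion; the real torchtitan signatures may differ):

```python
# Hypothetical standalone version of the float8 configuration check discussed
# above. Below SM89, float8 can only run in emulation mode, and emulation only
# works in eager mode (no torch.compile); any other combination raises.
def validate_float8_config(has_sm89: bool, emulate: bool, compile_enabled: bool) -> None:
    if has_sm89 or (emulate and not compile_enabled):
        return  # runnable: real hardware support, or eager-mode emulation
    raise ValueError(
        "float8 is only supported on SM89 or later; on older hardware, "
        "set float8.emulate=True and disable torch.compile."
    )

# Eager-mode emulation on older hardware is accepted:
validate_float8_config(has_sm89=False, emulate=True, compile_enabled=False)
```

Writing the condition in the "positive" form first, with a bare `pass`, keeps the runnable cases in one place instead of negating the whole expression.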
tianyu-l
left a comment
The CPU CI error is because we changed the warning to an exception when sm < 89.
I think we can just add the emulate flag to https://github.com/pytorch/torchtitan/blob/main/tests/unit_tests/test_model_converter.py#L42
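A hedged sketch of that unit-test fix, with the converter and config stubbed out (class and field names here are illustrative, not the actual torchtitan or test_model_converter API):

```python
from dataclasses import dataclass

# Illustrative stand-in for the real torchtitan float8 config.
@dataclass
class Float8Config:
    emulate: bool = False

def build_float8_converter(config: Float8Config, has_sm89: bool) -> str:
    # Mirrors the new behavior: on sm < 89 without emulation, raise instead of warn.
    if not has_sm89 and not config.emulate:
        raise ValueError("float8 requires SM89+; set float8.emulate=True for CPU CI")
    return "emulated" if config.emulate else "hardware"

# The CPU unit test would simply enable emulation so construction succeeds:
assert build_float8_converter(Float8Config(emulate=True), has_sm89=False) == "emulated"
```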
tianyu-l
left a comment
LGTM. Please address final comments before merge.
.github/workflows/unit_test_cpu.yaml
Outdated
|
|
pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cpu
curious, can we specify USE_CPP=0 here too?
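Assuming the torchao CPU wheel install honors the same environment variable (worth checking against the torchao build docs), the workflow step could read:

```shell
# Hypothetical CI step: skip building C++ extensions for the CPU-only job.
USE_CPP=0 pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cpu
```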
Add [emulate](https://github.com/pytorch/ao/blob/554cb60c750e6ef31bbcafec74bb76a4578902da/torchao/float8/config.py#L193) in float8 to enable testing on older hardware, and update the related warnings.

Test result: tested locally on an 8×H100 server.

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd`

<img width="1127" alt="Screenshot 2025-05-21 at 2 38 39 PM" src="https://github.com/user-attachments/assets/c15fcabb-d7cd-4c96-8ff4-9fe5a2bc5246" />

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --float8.emulate`

<img width="1127" alt="Screenshot 2025-05-21 at 2 39 01 PM" src="https://github.com/user-attachments/assets/9227c839-fbe6-45d0-a919-6b62ac66863a" />