TorchTitan e2e test on torchcomms device mesh #1847
Conversation
Summary: Composability testing with TorchComms and distributed training in TorchTitan.
- Training with `torchcomms.new_comm`
- Device mesh initialization with `torchcomms.init_device_mesh`
- Integration and testing with `fully_shard`

Differential Revision: D82171763
Force-pushed 36ab517 to 7c6435e, then 7c6435e to 8998f1c.
tianyu-l
left a comment
Thanks for the PR! The change looks interesting!
Is this PR for exploration, or is it ready to ship to the community? If it's the former, could we start with a branch / fork instead of experiments?
Sorry to put a hold on this before we get more context.
tianyu-l
left a comment
OK, got clarifications offline. I think it's OK to host this experiment. To land, we'll need to:
- simplify the code, as I believe a lot of existing components could be reused
- set up a PoC for this folder (let's work together on this)
Sorry for the late reply. It's planned to ship to the community as a use case for torchcomms.
I was cleaning up the code; the main change here is the device mesh initialization.
How can I set up the PoC?
Force-pushed 8998f1c to a6b3a47, a6b3a47 to 09e6610, then 09e6610 to 288f37f.
I'll do this with a PR shortly. Should we assign you as the PoC?
Yeah, please. There will be some further changes to enable other parallelisms and related tests.
fegin
left a comment
It is reasonable to duplicate `ParallelDims._build_mesh_without_ep`, but `Trainer.__init__` seems to be mostly the same, and `Trainer.__init__` is very long, so it is not easy to spot the difference. Can you point out what changes in `Trainer.__init__`? We can brainstorm how to further minimize the duplication.
| ---
| #### Example
| ```bash
| TEST_BACKEND=nccl ./run_train.sh --model.name torchcomms
This doesn't seem to be correct. You will at least need to specify CONFIG_FILE.
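For reference, a command including a config file might look like the sketch below. The path is an assumption, inferred from the llama3_8b.toml config mentioned in the test plan; adjust it to your checkout.

```bash
# Hedged example: CONFIG_FILE path assumed from the Llama 3 8B config
# referenced in this PR's test plan.
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" \
TEST_BACKEND=nccl \
TRAIN_FILE=torchtitan.experiments.torchcomms.train \
./run_train.sh --model.name torchcomms
```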
| - Training with `torchcomms.new_comm`
| - Device mesh initialization with `torchcomms.init_device_mesh`
| - **Composability Testing**
| - Integration and testing with `fully_shard` (FSDP)
Is this FSDP2 only? I thought you also verified it with TP. cc @fduwjj
We are working on N-D now; we will update the README later.
@fegin there are still some gaps on the N-D side, so we aim to merge this PR with 1D only first. This scales down the scope of this PR; we will have more PRs down the road.
| # init distributed and build meshes
| dist_utils.init_distributed(
|     job_config.comm,
|     enable_cpu_backend=job_config.training.enable_cpu_offload,
|     base_folder=job_config.job.dump_folder,
| )
| world_size = int(os.environ["WORLD_SIZE"])
| parallelism_config = job_config.parallelism
| self.parallel_dims = parallel_dims = ParallelDimsForComms(
|     dp_shard=parallelism_config.data_parallel_shard_degree,
|     dp_replicate=parallelism_config.data_parallel_replicate_degree,
|     cp=parallelism_config.context_parallel_degree,
|     tp=parallelism_config.tensor_parallel_degree,
|     pp=parallelism_config.pipeline_parallel_degree,
|     ep=parallelism_config.expert_parallel_degree,
|     etp=parallelism_config.expert_tensor_parallel_degree,
|     world_size=world_size,
| )
IIUC, only this part of the initialization is changed. Is this correct? Or can you point out other things you changed?
Yeah, I had some other changes before, but now that's the only change.
I will try to find a way to call `ParallelDimsForComms` here while avoiding copying `Trainer.__init__`.
You can give the original Trainer a class variable, e.g. `parallel_dims_cls`, and use that variable in `__init__` to construct `self.parallel_dims = parallel_dims`. Then you can just create a CommTrainer and replace that class variable.
Another approach is to factor the following code into a method, e.g. `def create_parallel_dims(self, config) -> None:`.
self.parallel_dims = parallel_dims = ParallelDimsForComms(
    dp_shard=parallelism_config.data_parallel_shard_degree,
    dp_replicate=parallelism_config.data_parallel_replicate_degree,
    cp=parallelism_config.context_parallel_degree,
    tp=parallelism_config.tensor_parallel_degree,
    pp=parallelism_config.pipeline_parallel_degree,
    ep=parallelism_config.expert_parallel_degree,
    etp=parallelism_config.expert_tensor_parallel_degree,
    world_size=world_size,
)
Sounds OK. I prefer the second option as it's a bit more straightforward. Maybe it should be called `_create_parallel_dims`, as it's not supposed to be called from outside.
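A minimal self-contained sketch of the agreed approach (toy field names; the real `Trainer` and `ParallelDims` live in torchtitan and carry many more fields):

```python
import os
from dataclasses import dataclass


@dataclass
class ParallelDims:
    dp_shard: int
    dp_replicate: int
    tp: int
    world_size: int


@dataclass
class ParallelDimsForComms(ParallelDims):
    # In the PR, this subclass builds its device mesh via torchcomms instead.
    pass


class Trainer:
    def __init__(self, parallelism_config: dict) -> None:
        # The long __init__ stays here unchanged; only the hook below differs
        # between the base trainer and the torchcomms experiment.
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.parallel_dims = self._create_parallel_dims(parallelism_config, world_size)

    def _create_parallel_dims(self, parallelism_config: dict, world_size: int) -> ParallelDims:
        return ParallelDims(world_size=world_size, **parallelism_config)


class CommsTrainer(Trainer):
    def _create_parallel_dims(self, parallelism_config: dict, world_size: int) -> ParallelDims:
        # The only override: swap in the torchcomms-aware ParallelDims.
        return ParallelDimsForComms(world_size=world_size, **parallelism_config)


if __name__ == "__main__":
    trainer = CommsTrainer({"dp_shard": 2, "dp_replicate": 1, "tp": 1})
    assert isinstance(trainer.parallel_dims, ParallelDimsForComms)
```

The alternative (a `parallel_dims_cls` class variable) works the same way, but the method hook keeps the construction logic in one overridable place.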
fduwjj
left a comment
Can you run a job and paste the loss curve from TensorBoard here?
Force-pushed 7de8101 to b270b5a.
Also, we will have more convergence and perf tests down the road as follow-up PRs.
| @@ -0,0 +1,20 @@
| # TorchTitan & TorchComms Composability Testing
|
| This repository provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
This is currently in bold font and looks a bit obtrusive. Could you adjust it to use a plain font?
| @@ -0,0 +1,20 @@
| # TorchTitan & TorchComms Composability Testing
|
| This repository provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
Suggested change:
| - This repository provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
| + This folder provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
| This repository provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
| ---
| #### Example
Mention that the command below uses Llama 3 as an example, but it should work with all models.
| ---
| #### Example
| ```bash
| TEST_BACKEND={backend} TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms
`TEST_BACKEND={backend}`
What should this be?
Users can input the backend they want to use, e.g. nccl or another backend.
It's a bit confusing here; I will change it to TEST_BACKEND=nccl.
Can we mention all the available backends? From the README it's hard to tell what people should put here.
Let's mention nccl, gloo, or any other user-defined custom backend for now. Also, let's mention that a custom backend needs to implement the TorchComm wrapper. (We just don't mention the backends that cannot be mentioned at this moment.)
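In README terms, that could look like the following sketch (commands mirror the example above; gloo shown as the CPU option):

```bash
# NCCL backend (GPU)
TEST_BACKEND=nccl TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms

# Gloo backend (CPU), or any custom backend that implements the TorchComm wrapper
TEST_BACKEND=gloo TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms
```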
| from torchtitan.models.llama3.infra.parallelize import parallelize_llama
| from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
|
| register_train_spec(
Why do you need to register this TrainSpec?
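For context on what registration buys: `--model.name torchcomms` is resolved by name against the registered specs. A toy sketch of that mechanism (illustrative only; field names here are not torchtitan's actual `TrainSpec` fields):

```python
from dataclasses import dataclass
from typing import Callable, Dict

# Toy registry: register_train_spec in torchtitan works on the same principle,
# mapping a name (used by --model.name) to the components to train with.

@dataclass
class ToyTrainSpec:
    name: str
    parallelize_fn: Callable  # e.g. parallelize_llama in this PR

_REGISTRY: Dict[str, ToyTrainSpec] = {}

def register_toy_spec(spec: ToyTrainSpec) -> None:
    if spec.name in _REGISTRY:
        raise ValueError(f"spec {spec.name!r} already registered")
    _REGISTRY[spec.name] = spec

def get_toy_spec(name: str) -> ToyTrainSpec:
    return _REGISTRY[name]

# Registering under "torchcomms" is what makes
# `--model.name torchcomms` resolve to this experiment.
register_toy_spec(ToyTrainSpec(name="torchcomms", parallelize_fn=lambda m: m))
assert get_toy_spec("torchcomms").parallelize_fn is not None
```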
| ---
| #### Example
| ```bash
| TEST_BACKEND={backend} TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms
Let's set `CONFIG_FILE` here, too. You can refer to the examples in the main README.md.
torchtitan/train.py
Outdated
| f"(warmup {job_config.lr_scheduler.warmup_steps})" | ||
| ) | ||
|
|
||
| def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims: |
Suggested change:
| - def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
| + def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
| class CommsTrainer(Trainer):
|     parallel_dims: ParallelDimsForComms
|
|     def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
Suggested change:
| - def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
| + def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
| from .parallel_dims import ParallelDimsForComms
|
|
| class CommsTrainer(Trainer):
Suggested change:
| - class CommsTrainer(Trainer):
| + class TorchCommsTrainer(Trainer):
| @dataclass
| class ParallelDimsForComms(ParallelDims):
Suggested change:
| - class ParallelDimsForComms(ParallelDims):
| + class TorchCommsParallelDims(ParallelDims):
Shall we remove this file?
I think we can remove this file for now.
| @@ -0,0 +1,22 @@
| # TorchTitan & TorchComms Composability Testing
|
| This folder provides a framework for composability testing with TorchComms and distributed training in TorchTitan. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
Seems the font hasn't been fixed.
torchcomms
Outdated
For now, we cannot mention too many details. We will add more context when it goes public. We need to merge this PR first so that the titan integration can go out with the torchcomms release.
@mori360 let's add a TODO here to add more explanation once torchcomms goes public.
Looks like you have a lint error as well?
fduwjj
left a comment
Thanks for doing this, looks good to me now.
| - Integration and testing with `fully_shard` (FSDP)
| ---
| ### To Be Added
| - Integration and testing with additional parallelism strategies (e.g., tensor, pipeline, model parallelism) other than fully_shard
Can you remove model parallelism or replace it with context parallelism? Thanks.

Summary:
Composability testing with TorchComms and distributed training in TorchTitan.
torchcomms.new_commtorchcomms.init_device_meshfully_shardDifferential Revision: D82171763
Test plan:
TEST_BACKEND=nccl TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms
Loss curve:
running 1000 steps on llama3_8b.toml