
Implement debug mode for GarbageCollection#1230

Merged
fegin merged 3 commits into main from gc_debug
May 28, 2025
Conversation

@fegin
Contributor

@fegin fegin commented May 28, 2025

Summary:
When debug mode is turned on, 1) `warn_tensor_cycles()` is called on rank0 and 2) `gc.collect()` is called every iteration, to help understand possible memory (tensor) leaks.

Reference: https://pytorch.org/blog/understanding-gpu-memory-2/

The current TorchTitan shows memory leakage:

```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh  --training.steps=20
```
```
[rank0]:[titan] 2025-05-27 20:44:18,824 - root - INFO - step:  1  loss: 12.2721  memory: 42.16GiB(44.38%)  tps: 1,677  tflops: 97.12  mfu: 9.82%
[rank0]:[titan] 2025-05-27 20:44:18,824 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-05-27 20:44:33,456 - root - INFO - step: 10  loss: 10.0301  memory: 68.72GiB(72.34%)  tps: 5,040  tflops: 291.86  mfu: 29.51%
[rank0]:[titan] 2025-05-27 20:44:48,141 - root - INFO - step: 20  loss:  8.4547  memory: 90.29GiB(95.03%)  tps: 5,579  tflops: 323.12  mfu: 32.67%
[rank0]:[titan] 2025-05-27 20:44:48,150 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-05-27 20:44:50,152 - root - INFO - Training completed
[rank0]:[titan] 2025-05-27 20:44:50,569 - root - INFO - Process group destroyed.
```

With this PR, we can use the following command to debug:

```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh  --training.steps=20 --training.gc_debug
```
```
[rank0]:[titan] 2025-05-27 20:49:04,858 - root - INFO - Force GC to perform collection to get the debug information.
[rank0]:[rank0]:W0527 20:49:05.414000 2423031 torch/utils/viz/_cycles.py:498] Reference cycle includes a CUDA Tensor see visualization of cycle /tmp/tmp9oginps8.html
[rank0]:[rank0]:W0527 20:49:05.687000 2423031 torch/utils/viz/_cycles.py:59] CUDA Memory changed during GC, 2147483648 bytes freed.
[rank0]:[titan] 2025-05-27 20:49:07,157 - root - INFO - step: 20  loss:  8.3943  memory: 49.66GiB(52.27%)  tps: 3,573  tflops: 206.93  mfu: 20.92%
[rank0]:[titan] 2025-05-27 20:49:07,167 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-05-27 20:49:09,169 - root - INFO - Training completed
[rank0]:[titan] 2025-05-27 20:49:10,198 - root - INFO - Process group destroyed.
```

`warn_tensor_cycles()` shows that 1) there are reference cycles that include CUDA tensors, and 2) 2GB of GPU memory is freed when `gc.collect()` is called. The visualization suggests the reference cycle comes from activation checkpointing.

Screenshot 2025-05-27 at 8 52 00 PM
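
The debug mode described above can be sketched with a stdlib-only helper (the class name, `gc_freq`, and `debug` flag are assumed from this PR; the actual rank0 call to `torch.utils.viz._cycles.warn_tensor_cycles()` is only noted in a comment so the sketch runs without torch):

```python
import gc
import logging

logger = logging.getLogger("titan")

class GarbageCollection:
    """Sketch of a scheduled-GC helper: manual collection every `gc_freq`
    steps in normal mode, a full collection every iteration in debug mode."""

    def __init__(self, gc_freq: int = 1000, debug: bool = False):
        self.gc_freq = gc_freq
        self.debug = debug
        gc.disable()  # take over scheduling from the automatic collector
        if debug:
            # On rank0 the PR additionally calls
            # torch.utils.viz._cycles.warn_tensor_cycles() here, which warns
            # (with an HTML visualization) whenever a collected reference
            # cycle contains a CUDA tensor.
            pass

    def run(self, step_count: int) -> None:
        if self.debug:
            logger.info("Force GC to perform collection to get the debug information.")
            gc.collect()  # default generation=2: full collection
        elif step_count % self.gc_freq == 0:
            gc.collect(1)  # lighter young-generation pass for periodic GC
```

Calling `run()` once per training step gives the per-iteration `gc.collect()` behavior shown in the debug log above.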

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 28, 2025
@fegin
Contributor Author

fegin commented May 28, 2025

cc @soulitzer, could you help review the potential memory-leak issue related to activation checkpointing?

@fegin fegin requested a review from tianyu-l May 28, 2025 03:57
Contributor

@tianyu-l tianyu-l left a comment


Sounds good to me. I had a question on the GC generation used in the different scenarios.

```
def run(self, step_count: int):
    if self.debug:
        logger.info("Force GC to perform collection to obtain debug information.")
        gc.collect()
```
Contributor

Could you elaborate a bit on the choice of the generation arg of gc.collect()?

  • I guess we have to use default generation=2 under debug mode to identify all reference cycles -- so generation=1 won't work?
  • I'm assuming we use gc.collect(1) for the periodic GC just to keep it lightweight?

Contributor Author

@fegin fegin May 28, 2025


> I guess we have to use default generation=2 under debug mode to identify all reference cycles -- so generation=1 won't work?

Yes

> I'm assuming we use gc.collect(1) for the periodic GC just to keep it lightweight?

Also yes
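
The distinction discussed above can be seen with a small stdlib example (pure illustration, unrelated to the PR's code): a cycle promoted to the oldest generation survives a young-generation pass but is reclaimed by a full collection. Note that the exact generation semantics changed with Python 3.13's incremental collector, so the gen-1 comment describes the classic (<=3.12) generational model.

```python
import gc
import weakref

class Node:
    def __init__(self):
        self.ref = None

# Build a reference cycle and age it into the oldest generation.
a, b = Node(), Node()
a.ref, b.ref = b, a
gc.collect()        # full pass: survivors are promoted to the oldest generation
wr = weakref.ref(a)
del a, b            # cycle is now unreachable, but refcounts never reach zero

gc.collect(1)       # young-generation pass: on the classic collector this does
                    # not scan the oldest generation, so the cycle can survive
gc.collect()        # full pass (generation=2): detects and frees the cycle
```

This is why debug mode needs the default full `gc.collect()` to surface all cycles, while `gc.collect(1)` keeps the periodic pass cheap.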


```
def run(self, step_count: int):
    if self.debug:
        logger.info("Force GC to perform collection to obtain debug information.")
```
Contributor

Potentially we can make generation an argument of GarbageCollection.collect, so that this log line can state the reason for the collection.

@fegin fegin merged commit f8f545c into main May 28, 2025
6 checks passed
fegin added a commit that referenced this pull request May 28, 2025
Improve the code based on the #1230 (comment)
@fegin fegin mentioned this pull request May 28, 2025
fegin added a commit that referenced this pull request May 28, 2025
Improve the code based on the #1230 (comment)
@soulitzer
Contributor

@fegin Hmm, I wasn't able to reproduce the leak.

From taking a look at the graph, it looks like the reference cycle comes from the dynamo wrapper, which keeps the recompute_fn alive, and the recompute_fn in turn keeps alive the storage dict for SAC.
It is interesting that the storage dict keeps any tensors alive at all, though, because the tensors should be popped as the recompute consumes the tensors cached by SAC...

@tianyu-l tianyu-l deleted the gc_debug branch May 28, 2025 20:31
@soulitzer
Contributor

Update: root cause appears to be pytorch/pytorch#153300 where the compile disable decorator leads to a reference cycle being created.
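
The mechanism is easy to reproduce with a hypothetical stand-in for the disable decorator (`disable_like` and `recompute_fn` are illustrative names, not torch's actual code): a wrapper whose closure references the wrapped function, plus a back-reference from the function to the wrapper, forms a cycle that refcounting alone cannot free.

```python
import gc
import weakref

def disable_like(fn):
    """Hypothetical stand-in for a compile-disable decorator (not torch's code)."""
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    fn.inner = wrapper      # back-reference closes the cycle: fn -> wrapper -> (closure) -> fn
    return wrapper

def recompute_fn():
    return "recomputed"

gc.disable()                # rule out automatic collection for the demo
w = disable_like(recompute_fn)
wr = weakref.ref(w)
del w, recompute_fn

alive_before = wr() is not None   # True: the cycle keeps the pair alive
gc.collect()                      # only a cycle-detecting pass reclaims it
alive_after = wr() is not None    # False: the cycle has been freed
gc.enable()
```

When such a wrapper transitively holds GPU tensors (as with the SAC storage dict here), that memory stays allocated until the next full `gc.collect()`, matching the "2147483648 bytes freed" warning in the debug log.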

@fegin
Contributor Author

fegin commented May 29, 2025

Created PyTorch issue: pytorch/pytorch#154642

wwwjn pushed a commit to wwwjn/torchtitan that referenced this pull request Jun 2, 2025
wwwjn pushed a commit to wwwjn/torchtitan that referenced this pull request Jun 2, 2025
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026

Labels

CLA Signed This label is managed by the Meta Open Source bot.

4 participants