Conversation
Summary: When debug mode is turned on, 1) `warn_tensor_cycles()` will be called on rank0, and 2) `gc.collect()` will be called every iteration to help identify possible memory (tensor) leaks.
cc @soulitzer, could you help review the potential memory leak issue related to activation checkpointing?
tianyu-l
left a comment
Sounds good to me. Had a question on the generation level used for different scenarios.
```python
def run(self, step_count: int):
    if self.debug:
        logger.info("Force GC to perform collection to obtain debug information.")
        gc.collect()
```
Could you elaborate a bit on the choice of the `generation` arg of `gc.collect()`?
- I guess we have to use the default `generation=2` under debug mode to identify all reference cycles -- so `generation=1` won't work?
- I'm assuming we are using `gc.collect(1)` for periodic GC just to make it not so heavy?
I guess we have to use the default `generation=2` under debug mode to identify all reference cycles -- so `generation=1` won't work?
Yes
I'm assuming we are using `gc.collect(1)` for periodic GC just to make it not so heavy?
Also yes
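The distinction discussed above can be sketched with a minimal stdlib-only example: reference counting alone can never reclaim a cycle, while a full `gc.collect()` (the default, `generation=2`) scans every generation and is guaranteed to find it; `gc.collect(1)` only scans the younger generations, which is cheaper and suits periodic GC but can miss cycles that have aged into the oldest generation.

```python
import gc
import weakref

class Node:
    def __init__(self):
        self.me = self  # self-reference creates a cycle

gc.disable()            # rule out background collections for determinism
n = Node()
ref = weakref.ref(n)
del n

# Reference counting alone cannot reclaim the cycle, so the object
# survives after its last external reference is dropped...
assert ref() is not None

# ...only the cyclic collector frees it. The default full collection
# (generation=2) scans all generations, which is why debug mode uses
# it to surface every reference cycle.
gc.collect()
assert ref() is None
gc.enable()
```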
```python
def run(self, step_count: int):
    if self.debug:
        logger.info("Force GC to perform collection to obtain debug information.")
```
Potentially we can make `generation` an arg to `GarbageCollection.collect`, so that this log line can also state the reason for the collection.
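The suggestion above could be sketched as follows. This is a minimal, hypothetical version: the class and method names (`GarbageCollection`, `collect`, `gc_freq`) are assumed from the snippet under review, and the real torchtitan implementation may differ.

```python
import gc
import logging

logger = logging.getLogger(__name__)

class GarbageCollection:
    """Hypothetical sketch: ``generation`` becomes an argument of
    ``collect`` so each call can log the reason for the collection."""

    def __init__(self, gc_freq: int = 1000, debug: bool = False):
        self.gc_freq = gc_freq
        self.debug = debug
        gc.disable()  # manual GC only, mirroring the pattern under review

    @staticmethod
    def collect(reason: str, generation: int = 1) -> None:
        logger.info("[GC] %s (generation=%d)", reason, generation)
        gc.collect(generation)

    def run(self, step_count: int) -> None:
        if self.debug:
            # Full collection to surface all reference cycles.
            self.collect("Force GC to obtain debug information.", generation=2)
        elif step_count > 1 and step_count % self.gc_freq == 0:
            # Lighter periodic pass over the younger generations.
            self.collect("Periodic GC.", generation=1)

# Usage sketch:
gc_handler = GarbageCollection(gc_freq=1000, debug=False)
gc_handler.run(step_count=1000)  # triggers the periodic generation-1 pass
gc.enable()                      # restore automatic GC after the demo
```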
Improve the code based on the #1230 (comment)
@fegin Hmm, I wasn't able to reproduce the leak. From taking a look at the graph, it looks like the reference cycle is from the dynamo wrapper, which is keeping the `recompute_fn` alive, and the `recompute_fn` is keeping alive the storage dict for SAC.
Update: the root cause appears to be pytorch/pytorch#153300, where the compile disable decorator leads to a reference cycle being created.
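The failure mode described above can be illustrated with a stdlib-only sketch. The names here (`disable`, `SACStorage`, `recompute_fn`, `_wrapper`) are illustrative stand-ins, not the actual dynamo or SAC internals: a wrapper and its wrapped function end up referencing each other, and as long as that cycle is uncollected, the wrapped function's closure pins the storage it captured.

```python
import gc
import weakref

gc.disable()  # rule out background collections while we observe

def disable(fn):
    """Hypothetical stand-in for a compile-disable decorator whose
    wrapper and wrapped function end up referencing each other."""
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    fn._wrapper = wrapper  # back-reference closes a wrapper <-> fn cycle
    return wrapper

class SACStorage:
    """Stand-in for the storage dict that SAC keeps per checkpoint."""

def make_recompute_fn():
    storage = SACStorage()
    def recompute_fn():
        # The closure cell keeps `storage` alive as long as this
        # function object is alive.
        return storage
    return recompute_fn, weakref.ref(storage)

recompute_fn, ref = make_recompute_fn()
wrapped = disable(recompute_fn)
del recompute_fn, wrapped

# The wrapper <-> recompute_fn cycle is unreachable but cannot be
# freed by reference counting alone, so the storage survives...
assert ref() is not None

# ...until the cyclic collector runs and breaks the cycle.
gc.collect()
assert ref() is None
gc.enable()
```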
Created PyTorch issue: pytorch/pytorch#154642
Summary: When debug mode is turned on, 1) `warn_tensor_cycles()` will be called on rank0, and 2) `gc.collect()` will be called every iteration to help identify possible memory (tensor) leaks.

Reference: https://pytorch.org/blog/understanding-gpu-memory-2/

The current TorchTitan shows memory leakage:

```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=20
```

```
[rank0]:[titan] 2025-05-27 20:44:18,824 - root - INFO - step: 1 loss: 12.2721 memory: 42.16GiB(44.38%) tps: 1,677 tflops: 97.12 mfu: 9.82%
[rank0]:[titan] 2025-05-27 20:44:18,824 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-05-27 20:44:33,456 - root - INFO - step: 10 loss: 10.0301 memory: 68.72GiB(72.34%) tps: 5,040 tflops: 291.86 mfu: 29.51%
[rank0]:[titan] 2025-05-27 20:44:48,141 - root - INFO - step: 20 loss: 8.4547 memory: 90.29GiB(95.03%) tps: 5,579 tflops: 323.12 mfu: 32.67%
[rank0]:[titan] 2025-05-27 20:44:48,150 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-05-27 20:44:50,152 - root - INFO - Training completed
[rank0]:[titan] 2025-05-27 20:44:50,569 - root - INFO - Process group destroyed.
```

With this PR, we can use the following command to debug:

```
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=20 --training.gc_debug
```

```
[rank0]:[titan] 2025-05-27 20:49:04,858 - root - INFO - Force GC to perform collection to get the debug information.
[rank0]:[rank0]:W0527 20:49:05.414000 2423031 torch/utils/viz/_cycles.py:498] Reference cycle includes a CUDA Tensor see visualization of cycle /tmp/tmp9oginps8.html
[rank0]:[rank0]:W0527 20:49:05.687000 2423031 torch/utils/viz/_cycles.py:59] CUDA Memory changed during GC, 2147483648 bytes freed.
[rank0]:[titan] 2025-05-27 20:49:07,157 - root - INFO - step: 20 loss: 8.3943 memory: 49.66GiB(52.27%) tps: 3,573 tflops: 206.93 mfu: 20.92%
[rank0]:[titan] 2025-05-27 20:49:07,167 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-05-27 20:49:09,169 - root - INFO - Training completed
[rank0]:[titan] 2025-05-27 20:49:10,198 - root - INFO - Process group destroyed.
```

`warn_tensor_cycles()` shows that 1) there are reference cycles that include CUDA tensors, and 2) 2GB of GPU memory is freed when `gc.collect()` is called. The visualization shows that the reference cycle seems to come from activation checkpointing.

<img width="1597" alt="Screenshot 2025-05-27 at 8 52 00 PM" src="https://github.com/user-attachments/assets/2e241baa-16fe-4e87-acce-fb72710babc2" />