[PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag #2311
Signed-off-by: Jaime Cardenas <[email protected]>
The benchmark script is updated to construct the unified `LayerNormMLP` with the new `checkpoint` flag in place of the separate `SelectiveLayerNormMLP` module, and the summary labels are renamed to match:

```diff
@@ -1,6 +1,6 @@
 import time
 import torch
-from transformer_engine.pytorch import SelectiveLayerNormMLP
+from transformer_engine.pytorch import LayerNormMLP
 from collections import defaultdict
 
 torch.manual_seed(1234)
@@ -32,10 +32,10 @@ def build(self):
 
         ln_list, sln_list = [], []
         for _ in range(self._layers):
-            ln = SelectiveLayerNormMLP(
+            ln = LayerNormMLP(
                 self._hidden_size, self._ffn_hidden_size, checkpoint=False
             ).to(device)
-            sln = SelectiveLayerNormMLP(
+            sln = LayerNormMLP(
                 self._hidden_size, self._ffn_hidden_size, checkpoint=True
             ).to(device)
             with torch.no_grad():
@@ -180,7 +180,7 @@ def _run_bwd(model, out):
             self.stats[desc]["diff"][key] = self._max_diff(ln_grads[key], sln_grads[key])
 
     def summarize(self):
-        _modules = [("ln_stats", "LayerNormMLP"), ("sln_stats", "SelectiveLayerNormMLP")]
+        _modules = [("ln_stats", "No Checkpointing"), ("sln_stats", "Checkpointing")]
         _metric_map = {"time": (1, "ms"), "mem": (1e-6, "MB")}
 
         left_w = 18  # "fwd time" / "bwd mem" label
```
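For context, a minimal usage sketch of the flag this PR adds, assuming the `checkpoint` keyword argument lands on `LayerNormMLP` as shown in the diff above; the sizes, device, and input shape below are illustrative, not taken from the benchmark:

```python
import torch
from transformer_engine.pytorch import LayerNormMLP

hidden_size, ffn_hidden_size = 1024, 4096  # illustrative sizes
device = "cuda"  # Transformer Engine modules run on CUDA devices

# Baseline: all activations needed for backward are kept in memory.
mlp = LayerNormMLP(hidden_size, ffn_hidden_size, checkpoint=False).to(device)

# Selective activation checkpointing (per this PR's flag): some activations
# are dropped during forward and recomputed during backward, trading extra
# compute for lower peak memory.
mlp_sac = LayerNormMLP(hidden_size, ffn_hidden_size, checkpoint=True).to(device)

x = torch.randn(2048, hidden_size, device=device, requires_grad=True)
out = mlp_sac(x)
out.sum().backward()  # recomputation happens here when checkpoint=True
```

The benchmark script builds both variants side by side and compares their forward/backward time, memory (ms and MB in `summarize`), and the maximum gradient difference between the two, so the flag can be validated for both savings and numerical agreement.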
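For readers unfamiliar with the underlying technique, here is a generic sketch of activation checkpointing in plain PyTorch. It illustrates the recompute-in-backward pattern only; it is not Transformer Engine's implementation, and a "selective" variant like this PR's would presumably recompute only chosen intermediates inside the fused block rather than rerunning the whole thing:

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyLayerNormMLP(torch.nn.Module):
    """Toy stand-in for a LayerNorm + MLP block; not the TE implementation."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int):
        super().__init__()
        self.norm = torch.nn.LayerNorm(hidden_size)
        self.fc1 = torch.nn.Linear(hidden_size, ffn_hidden_size)
        self.fc2 = torch.nn.Linear(ffn_hidden_size, hidden_size)

    def forward(self, x: torch.Tensor, use_checkpoint: bool = False) -> torch.Tensor:
        def _body(inp):
            # The GELU input/output are the large intermediates that
            # checkpointing avoids keeping alive for backward.
            return self.fc2(torch.nn.functional.gelu(self.fc1(self.norm(inp))))

        if use_checkpoint:
            # Store only the block input; rerun _body during backward to
            # rebuild the intermediates on demand.
            return checkpoint(_body, x, use_reentrant=False)
        return _body(x)


x = torch.randn(8, 1024, requires_grad=True)
block = ToyLayerNormMLP(1024, 4096)
block(x, use_checkpoint=True).sum().backward()
```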