[PyTorch] Implement Selective Activation Checkpointing for LayerNormMLP with checkpoint flag #2311
base: main
Changes from 1 commit
Commit message (pre-commit.ci auto-fixes): for more information, see https://pre-commit.ci. Signed-off-by: Jaime Cardenas <[email protected]>
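This commit itself only applies formatting fixes; the feature the PR adds, per its title, is a `checkpoint` flag on `LayerNormMLP`. A minimal sketch of how that flag would presumably be used (the keyword name is taken from the PR title and is an assumption; the rest follows the existing Transformer Engine `LayerNormMLP` constructor):

```python
import torch
import transformer_engine.pytorch as te

hidden_size, ffn_hidden_size = 2048, 8192

# Baseline: all activations needed for the backward pass are kept in memory.
mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size).cuda()

# With the PR's flag (assumed keyword name): selected intermediate activations
# are recomputed during backward instead of being stored, trading compute for memory.
mlp_ckpt = te.LayerNormMLP(hidden_size, ffn_hidden_size, checkpoint=True).cuda()

x = torch.randn(128, hidden_size, device="cuda")
mlp_ckpt(x).sum().backward()
```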
First changed file (a benchmark script comparing LayerNormMLP with and without selective activation checkpointing; the filename was not preserved in this page extract):

@@ -1,4 +1,3 @@
 import time
 import torch

(One line is removed from this import block; the deleted line itself did not survive the extraction.)
@@ -7,6 +6,7 @@
 torch.manual_seed(1234)
 device = torch.device("cuda")


 class _Sequential(torch.nn.Sequential):
     """Sequential model that forwards keyword arguments to modules"""
@@ -16,10 +16,11 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
         x = module(x, **kwargs)
         return x


 class ModelConfig:
     def __init__(
-        self,
-        hidden_size: int = 128,
+        self,
+        hidden_size: int = 128,
         ffn_hidden_size: int = 512,
         layers: int = 1,
     ):

(The removed and re-added `self,` / `hidden_size` lines differ only in indentation, which the extracted page does not show.)
@@ -48,14 +49,16 @@ def build(self):
         return ln_model, sln_model


 config = {
     # "small": ModelConfig(128, 512, 12),
     # "medium": ModelConfig(512, 2048, 12),
     # "large": ModelConfig(1024, 4096, 12),
     "huge": ModelConfig(2048, 8192, 12),
 }

-data_sizes = [2**7, 2**10, 2**14, 2**16]#2**18]
+data_sizes = [2**7, 2**10, 2**14, 2**16]  # 2**18]


 class Profiler:
     def __init__(self):
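From the fragments above, `ModelConfig.build()` returns a baseline stack (`ln_model`) and a checkpointed stack (`sln_model`) built from the configured sizes. A hypothetical reconstruction of how that could look; only the attribute names visible in the diff (`_hidden_size`, `_ffn_hidden_size`, `_layers`) and the `_Sequential`/`device` definitions shown earlier are taken from the PR, the rest is an assumption:

```python
import transformer_engine.pytorch as te


class ModelConfig:
    def __init__(self, hidden_size: int = 128, ffn_hidden_size: int = 512, layers: int = 1):
        self._hidden_size = hidden_size
        self._ffn_hidden_size = ffn_hidden_size
        self._layers = layers

    def build(self):
        # `_Sequential` and `device` are defined earlier in the benchmark script;
        # `checkpoint` is the flag this PR adds (assumed keyword name).
        def _stack(checkpoint: bool):
            return _Sequential(
                *[
                    te.LayerNormMLP(self._hidden_size, self._ffn_hidden_size, checkpoint=checkpoint)
                    for _ in range(self._layers)
                ]
            ).to(device)

        ln_model = _stack(checkpoint=False)  # baseline: keep all activations
        sln_model = _stack(checkpoint=True)  # recompute selected activations in backward
        return ln_model, sln_model
```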
@@ -68,7 +71,7 @@ def __init__(self):
                 "bwd_stats": {
                     "mem": [],
                     "time": [],
-                }
+                },
             },
             "sln_stats": {
                 "fwd_stats": {
@@ -78,7 +81,7 @@ def __init__(self):
                 "bwd_stats": {
                     "mem": [],
                     "time": [],
-                }
+                },
             },
             "diff": {
                 "out": [],
@@ -88,8 +91,7 @@ def __init__(self):
                 "fc1_bias": [],
                 "fc2_weight": [],
                 "fc2_bias": [],
-            }
-
+            },
         }

     def compare(self, ln_model, sln_model, data):
@@ -161,11 +163,19 @@ def _run_bwd(model, out):
         self.stats["sln_stats"]["bwd_stats"]["time"].append(sln_bwd_time)
         self.stats["sln_stats"]["bwd_stats"]["mem"].append(sln_bwd_mem)

-        for key in ["layer_norm_weight", "layer_norm_bias", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"]:
+        for key in [
+            "layer_norm_weight",
+            "layer_norm_bias",
+            "fc1_weight",
+            "fc1_bias",
+            "fc2_weight",
+            "fc2_bias",
+        ]:
             self.stats["diff"][key].append(self._max_diff(ln_grads[key], sln_grads[key]))

     def summarize(self):
         """Print a concise summary of collected statistics."""

         def _summarize(values):
             if not values:
                 return {"avg": 0.0, "min": 0.0, "max": 0.0}
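The `"time"`/`"mem"` lists populated above suggest that `Profiler.compare` records wall-clock time and peak CUDA memory for each forward and backward pass. A sketch of one common way to collect such numbers; the helper name and exact methodology are assumptions, not the PR's code:

```python
import time

import torch


def _measure(fn, *args, **kwargs):
    """Run `fn` once; return its result, elapsed seconds, and peak CUDA memory in bytes."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    result = fn(*args, **kwargs)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    peak_mem = torch.cuda.max_memory_allocated()
    return result, elapsed, peak_mem


# Hypothetical usage inside Profiler.compare():
#   out, fwd_time, fwd_mem = _measure(model, data)
#   _, bwd_time, bwd_mem = _measure(lambda: out.sum().backward())
```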
@@ -202,7 +212,14 @@ def _summarize(values):
         print(f"Forward output max diff avg: {summary:.3e}")

         print("Gradient max diff averages:")
-        for key in ["layer_norm_weight", "layer_norm_bias", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"]:
+        for key in [
+            "layer_norm_weight",
+            "layer_norm_bias",
+            "fc1_weight",
+            "fc1_bias",
+            "fc2_weight",
+            "fc2_bias",
+        ]:
             summary = sum(diff_stats[key]) / len(diff_stats[key])
             print(f" {key}: {summary:.3e}")
         print()
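Both key loops feed a `_max_diff` helper. A plausible minimal definition, inferred from the `:.3e`-formatted differences printed in `summarize` and offered here only as an assumption:

```python
import torch


def _max_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Largest element-wise absolute difference between two tensors."""
    return (a - b).abs().max().item()
```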
@@ -229,6 +246,7 @@ def _collect_param_grads(self, model):
     def _param_key(self, name):
         return name.split(".")[-1]


 def main():

     for size in config:
@@ -243,8 +261,12 @@ def main():
         profiler.compare(ln_model, sln_model, dummy_data)

-        print(f"summarizing comparison for seq={seq_len}, hidden={config[size]._hidden_size}, ffn_fidden={config[size]._ffn_hidden_size}, layers={config[size]._layers}\n")
+        print(
+            f"summarizing comparison for seq={seq_len}, hidden={config[size]._hidden_size},"
+            f" ffn_fidden={config[size]._ffn_hidden_size}, layers={config[size]._layers}\n"
+        )
         profiler.summarize()


 if __name__ == "__main__":
     main()
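For context, the visible pieces of `main()` (the `for size in config:` loop, the `profiler.compare` call, and the summary print) fit together roughly as sketched below; treating `data_sizes` as sequence lengths and the dummy-batch shape are assumptions:

```python
def main():
    for size in config:
        profiler = Profiler()
        ln_model, sln_model = config[size].build()
        for seq_len in data_sizes:
            # Assumed batch layout; the real script's shape and dtype are not shown here.
            dummy_data = torch.randn(seq_len, config[size]._hidden_size, device=device)
            profiler.compare(ln_model, sln_model, dummy_data)
            print(
                f"summarizing comparison for seq={seq_len}, hidden={config[size]._hidden_size},"
                f" ffn_fidden={config[size]._ffn_hidden_size}, layers={config[size]._layers}\n"
            )
            profiler.summarize()
```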
Second changed file (the test containing test_make_graphed_callables_with_fp8_weight_caching); the change appears to be only the removal of a trailing blank line:

@@ -397,4 +397,3 @@ def test_make_graphed_callables_with_fp8_weight_caching(
             fp8_recipe=fp8_recipe,
             fp8_weight_caching=True,
         )
-