Commit f4f8541

grad checkpointing (huggingface#4474)
* grad checkpointing
* fix make fix-copies
* fix

Co-authored-by: Patrick von Platen <[email protected]>
1 parent e1b5b8b commit f4f8541

File tree (2 files changed, +59 / -83 lines):

  src/diffusers/models/unet_2d_blocks.py
  src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

src/diffusers/models/unet_2d_blocks.py

Lines changed: 41 additions & 56 deletions
@@ -648,16 +648,13 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
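The change is the same in every hunk of this file: the attention block is no longer routed through torch.utils.checkpoint.checkpoint via the create_custom_forward wrapper with positional (and placeholder None) arguments; it is called directly with keyword arguments, while the resnet keeps its explicit checkpoint wrapper. A minimal, self-contained sketch of that pattern follows; the helper name and call shapes mirror the code in the diff, but the packaging-based version gate stands in for diffusers' internal is_torch_version helper, so treat it as illustrative rather than the library's exact implementation.

# Illustrative sketch of the checkpointing pattern in these hunks (assumptions:
# `packaging` replaces diffusers' internal `is_torch_version` helper, and the
# attn/resnet modules follow the call signatures shown in the diff).
from typing import Any, Dict

import torch
from packaging import version


def create_custom_forward(module, return_dict=None):
    # Wrap a module so torch.utils.checkpoint can invoke it with positional inputs.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


def resnet_then_attn(resnet, attn, hidden_states, temb, encoder_hidden_states):
    # Non-reentrant checkpointing only exists on PyTorch >= 1.11.
    ckpt_kwargs: Dict[str, Any] = (
        {"use_reentrant": False}
        if version.parse(torch.__version__) >= version.parse("1.11.0")
        else {}
    )
    # The resnet stays wrapped in activation checkpointing ...
    hidden_states = torch.utils.checkpoint.checkpoint(
        create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs
    )
    # ... while the attention block is now called directly with keyword arguments
    # and unpacked from its tuple return value.
    hidden_states = attn(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        return_dict=False,
    )[0]
    return hidden_states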
@@ -1035,16 +1032,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1711,13 +1705,12 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1912,15 +1905,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2173,16 +2164,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2872,13 +2860,12 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -3094,16 +3081,14 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
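For context on how these branches are reached: gradient checkpointing is opt-in on diffusers models, and the patched code paths only run when the model is in training mode with checkpointing enabled. A usage sketch follows; it is not part of this diff, and the checkpoint id, tensor shapes, and dummy backward pass are illustrative assumptions for an SD 1.5-style UNet.

# Illustrative usage only (not part of this commit): enable gradient checkpointing
# on a UNet so the branches patched above are exercised during training.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # example checkpoint
)
unet.enable_gradient_checkpointing()  # flips gradient_checkpointing on supporting blocks
unet.train()  # the checkpointed branch also requires self.training to be True

# Dummy forward/backward pass to exercise the recompute-on-backward path.
sample = torch.randn(1, 4, 64, 64)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)
noise_pred = unet(sample, timestep, encoder_hidden_states).sample
noise_pred.mean().backward()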

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 18 additions & 27 deletions
@@ -1429,16 +1429,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1668,16 +1665,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1809,16 +1803,13 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
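The modeling_text_unet.py hunks mirror the unet_2d_blocks.py ones because Versatile Diffusion keeps a copy of these blocks that is synchronized with `make fix-copies` (the "fix make fix-copies" item in the commit message). A plausible reading of why the outer checkpoint wrapper can be dropped is that the attention modules already apply gradient checkpointing to their own sub-blocks internally, so wrapping the whole call again was redundant; the sketch below shows that internal pattern under that assumption and is not copied from diffusers.

# Minimal sketch (assumed structure, not the diffusers implementation) of a module
# that checkpoints its own sub-blocks, which is why callers no longer need to wrap
# the whole attention module in torch.utils.checkpoint themselves.
import torch
import torch.nn as nn


class SelfCheckpointingBlock(nn.Module):
    def __init__(self, dim: int = 64, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.gradient_checkpointing = False  # typically toggled by enable_gradient_checkpointing()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                # Recompute this layer's activations on backward instead of storing them
                # (use_reentrant=False requires PyTorch >= 1.11).
                hidden_states = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, use_reentrant=False
                )
            else:
                hidden_states = layer(hidden_states)
        return hidden_states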
