Skip to content

Commit c6f8c31

Browse files
BaaBaaGoatDN6
andauthored
Fix forward pass in UNetMotionModel when gradient checkpoint is enabled (huggingface#6744)
fix huggingface#6742 Co-authored-by: Dhruv Nair <[email protected]>
1 parent 64909f1 commit c6f8c31

File tree

1 file changed

+10
-21
lines changed

1 file changed

+10
-21
lines changed

src/diffusers/models/unets/unet_3d_blocks.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,16 +1031,10 @@ def custom_forward(*inputs):
10311031
hidden_states = torch.utils.checkpoint.checkpoint(
10321032
create_custom_forward(resnet), hidden_states, temb, scale
10331033
)
1034-
hidden_states = torch.utils.checkpoint.checkpoint(
1035-
create_custom_forward(motion_module),
1036-
hidden_states.requires_grad_(),
1037-
temb,
1038-
num_frames,
1039-
)
10401034

10411035
else:
10421036
hidden_states = resnet(hidden_states, temb, scale=scale)
1043-
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1037+
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
10441038

10451039
output_states = output_states + (hidden_states,)
10461040

@@ -1221,10 +1215,10 @@ def custom_forward(*inputs):
12211215
encoder_attention_mask=encoder_attention_mask,
12221216
return_dict=False,
12231217
)[0]
1224-
hidden_states = motion_module(
1225-
hidden_states,
1226-
num_frames=num_frames,
1227-
)[0]
1218+
hidden_states = motion_module(
1219+
hidden_states,
1220+
num_frames=num_frames,
1221+
)[0]
12281222

12291223
# apply additional residuals to the output of the last pair of resnet and attention blocks
12301224
if i == len(blocks) - 1 and additional_residuals is not None:
@@ -1425,10 +1419,10 @@ def custom_forward(*inputs):
14251419
encoder_attention_mask=encoder_attention_mask,
14261420
return_dict=False,
14271421
)[0]
1428-
hidden_states = motion_module(
1429-
hidden_states,
1430-
num_frames=num_frames,
1431-
)[0]
1422+
hidden_states = motion_module(
1423+
hidden_states,
1424+
num_frames=num_frames,
1425+
)[0]
14321426

14331427
if self.upsamplers is not None:
14341428
for upsampler in self.upsamplers:
@@ -1563,15 +1557,10 @@ def custom_forward(*inputs):
15631557
hidden_states = torch.utils.checkpoint.checkpoint(
15641558
create_custom_forward(resnet), hidden_states, temb
15651559
)
1566-
hidden_states = torch.utils.checkpoint.checkpoint(
1567-
create_custom_forward(resnet),
1568-
hidden_states,
1569-
temb,
1570-
)
15711560

15721561
else:
15731562
hidden_states = resnet(hidden_states, temb, scale=scale)
1574-
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1563+
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
15751564

15761565
if self.upsamplers is not None:
15771566
for upsampler in self.upsamplers:

0 commit comments

Comments
 (0)