@@ -1031,16 +1031,10 @@ def custom_forward(*inputs):
1031
1031
hidden_states = torch .utils .checkpoint .checkpoint (
1032
1032
create_custom_forward (resnet ), hidden_states , temb , scale
1033
1033
)
1034
- hidden_states = torch .utils .checkpoint .checkpoint (
1035
- create_custom_forward (motion_module ),
1036
- hidden_states .requires_grad_ (),
1037
- temb ,
1038
- num_frames ,
1039
- )
1040
1034
1041
1035
else :
1042
1036
hidden_states = resnet (hidden_states , temb , scale = scale )
1043
- hidden_states = motion_module (hidden_states , num_frames = num_frames )[0 ]
1037
+ hidden_states = motion_module (hidden_states , num_frames = num_frames )[0 ]
1044
1038
1045
1039
output_states = output_states + (hidden_states ,)
1046
1040
@@ -1221,10 +1215,10 @@ def custom_forward(*inputs):
1221
1215
encoder_attention_mask = encoder_attention_mask ,
1222
1216
return_dict = False ,
1223
1217
)[0 ]
1224
- hidden_states = motion_module (
1225
- hidden_states ,
1226
- num_frames = num_frames ,
1227
- )[0 ]
1218
+ hidden_states = motion_module (
1219
+ hidden_states ,
1220
+ num_frames = num_frames ,
1221
+ )[0 ]
1228
1222
1229
1223
# apply additional residuals to the output of the last pair of resnet and attention blocks
1230
1224
if i == len (blocks ) - 1 and additional_residuals is not None :
@@ -1425,10 +1419,10 @@ def custom_forward(*inputs):
1425
1419
encoder_attention_mask = encoder_attention_mask ,
1426
1420
return_dict = False ,
1427
1421
)[0 ]
1428
- hidden_states = motion_module (
1429
- hidden_states ,
1430
- num_frames = num_frames ,
1431
- )[0 ]
1422
+ hidden_states = motion_module (
1423
+ hidden_states ,
1424
+ num_frames = num_frames ,
1425
+ )[0 ]
1432
1426
1433
1427
if self .upsamplers is not None :
1434
1428
for upsampler in self .upsamplers :
@@ -1563,15 +1557,10 @@ def custom_forward(*inputs):
1563
1557
hidden_states = torch .utils .checkpoint .checkpoint (
1564
1558
create_custom_forward (resnet ), hidden_states , temb
1565
1559
)
1566
- hidden_states = torch .utils .checkpoint .checkpoint (
1567
- create_custom_forward (resnet ),
1568
- hidden_states ,
1569
- temb ,
1570
- )
1571
1560
1572
1561
else :
1573
1562
hidden_states = resnet (hidden_states , temb , scale = scale )
1574
- hidden_states = motion_module (hidden_states , num_frames = num_frames )[0 ]
1563
+ hidden_states = motion_module (hidden_states , num_frames = num_frames )[0 ]
1575
1564
1576
1565
if self .upsamplers is not None :
1577
1566
for upsampler in self .upsamplers :
0 commit comments