Closed
Labels: bug (Something isn't working)
Description
Describe the bug

DownBlockMotion and UpBlockMotion can be called normally without gradient_checkpointing, but fail when it is enabled.
Something seems wrong here:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_3d_blocks.py
lines 1034–1043:
```python
if self.training and self.gradient_checkpointing:
    ...
    hidden_states = torch.utils.checkpoint.checkpoint(
        create_custom_forward(motion_module),
        hidden_states.requires_grad_(),
        temb,
        num_frames,
    )  # temb is passed here
else:
    hidden_states = resnet(hidden_states, temb, scale=scale)
    hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]  # no temb here
```
It is strange that motion_module runs as self-attention during non-checkpointed training, but as cross-attention during checkpointed training, because the extra positional temb ends up in the encoder_hidden_states slot.
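If the intent is for both branches to call the motion module identically, a minimal alignment (an untested sketch against the linked code, not a confirmed fix) would be to drop temb from the checkpointed call:

```python
# Untested sketch: align the checkpointed call with the eager path
# by passing only hidden_states and num_frames to the motion module.
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(motion_module),
    hidden_states.requires_grad_(),
    num_frames,
)
```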
Reproduction

```python
import torch
from diffusers.models.unet_3d_blocks import DownBlockMotion as B

b = B(32, 32, 32)
b(torch.zeros(1, 32, 4, 4), torch.zeros(1, 32))[0].flatten()[:10]  # everything ok here

b.gradient_checkpointing = True
b(torch.zeros(1, 32, 4, 4), torch.zeros(1, 32))[0].flatten()[:10]  # raises ValueError: not enough values to unpack (expected 3, got 2)
```
Logs

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
in
     25 b.gradient_checkpointing=True
     26
---> 27 b(torch.zeros(1,32,4,4),torch.zeros(1,32))[0].flatten()[:10]

c:\Users\Administrator\anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *args, **kwargs)
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

c:\Users\Administrator\anaconda3\lib\site-packages\diffusers\models\unet_3d_blocks.py in forward(self, hidden_states, temb, scale, num_frames)
   1032     create_custom_forward(resnet), hidden_states, temb, scale
   1033 )
-> 1034 hidden_states = torch.utils.checkpoint.checkpoint(
   1035     create_custom_forward(motion_module),
   1036     hidden_states.requires_grad_(),

c:\Users\Administrator\anaconda3\lib\site-packages\torch\utils\checkpoint.py in checkpoint(function, use_reentrant, *args, **kwargs)
    247
    248 if use_reentrant:
--> 249     return CheckpointFunction.apply(function, preserve, *args)
...
-> 1225 batch_size, sequence_length, _ = (
   1226     hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
   1227 )
ValueError: not enough values to unpack (expected 3, got 2)
```
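The failing unpack can be reproduced in isolation. The sketch below (a hypothetical `toy_motion_forward`, with plain shape tuples standing in for tensors so no torch install is needed) mimics the motion module's forward, which takes no temb parameter: the extra positional temb from the checkpointed call fills the encoder_hidden_states slot, and the attention-style three-way unpack then sees a 2-D shape.

```python
# Hypothetical stand-in for the motion module's forward signature.
# Shape tuples replace tensors; the unpack logic is the same as in
# diffusers' attention code.
def toy_motion_forward(hidden_states, encoder_hidden_states=None, num_frames=1):
    batch_size, sequence_length, _ = (
        hidden_states if encoder_hidden_states is None else encoder_hidden_states
    )
    return hidden_states

hs = (1, 16, 32)   # (batch, seq, channels) hidden states
temb = (1, 32)     # 2-D time embedding

toy_motion_forward(hs, num_frames=4)   # eager-style call: unpacks fine

try:
    # checkpointed-style call: temb lands in encoder_hidden_states
    toy_motion_forward(hs, temb, 4)
except ValueError as e:
    print(e)  # not enough values to unpack (expected 3, got 2)
```

This matches the traceback: the checkpointed branch forwards temb positionally, so the module tries to unpack the 2-D embedding as a 3-D attention input.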
System Info

- diffusers version: 0.25.1
- Platform: Windows-10-10.0.19041-SP0
- Python version: 3.8.8
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- Huggingface_hub version: 0.20.3
- Transformers version: 4.37.0
- Accelerate version: 0.26.1
- xFormers version: not installed
- Using GPU in script?: the error is the same on CPU and GPU
- Using distributed or parallel set-up in script?: No
Who can help?