Skip to content

Commit 9254d1f

Browse files
authored
Pass device to enable_model_cpu_offload in maybe_free_model_hooks (huggingface#6937)
1 parent e1bdcc7 commit 9254d1f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1423,6 +1423,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
14231423

14241424
device_type = torch_device.type
14251425
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
1426+
self._offload_device = device
14261427

14271428
if self.device.type != "cpu":
14281429
self.to("cpu", silence_dtype_warnings=True)
@@ -1472,7 +1473,7 @@ def maybe_free_model_hooks(self):
14721473
hook.remove()
14731474

14741475
# make sure the model is in the same state as before calling it
1475-
self.enable_model_cpu_offload()
1476+
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
14761477

14771478
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
14781479
r"""
@@ -1508,6 +1509,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
15081509

15091510
device_type = torch_device.type
15101511
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
1512+
self._offload_device = device
15111513

15121514
if self.device.type != "cpu":
15131515
self.to("cpu", silence_dtype_warnings=True)

0 commit comments

Comments (0)