Pass device to enable_model_cpu_offload in maybe_free_model_hooks #6937
Conversation
Can you provide a code snippet?
```diff
  # make sure the model is in the same state as before calling it
- self.enable_model_cpu_offload()
+ self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
```
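For context, here is a minimal sketch of the surrounding method with the change applied. This is paraphrased rather than the exact diffusers source, and assumes the pipeline tracks its installed hooks in `self._all_hooks`:

```python
def maybe_free_model_hooks(self):
    # Sketch only: tear down the accelerate hooks installed by a previous
    # enable_model_cpu_offload() call, then reinstall them so the pipeline
    # ends up in the same state as before it was called.
    if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
        return  # offloading was never enabled, so there is nothing to free

    for hook in self._all_hooks:
        hook.offload()
        hook.remove()

    # make sure the model is in the same state as before calling it;
    # fall back to "cuda" only when no offload device was ever recorded
    self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
```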
Personally I think it's a non-breaking change, since we default to "cuda" for `device` in `enable_model_cpu_offload()`.
I think it's ok too :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
import torch
import intel_extension_for_pytorch  # noqa: F401 -- registers the "xpu" device with PyTorch
from diffusers import AutoPipelineForText2Image, AutoencoderKL, EulerAncestralDiscreteScheduler

model = "cagliostrolab/animagine-xl-3.0"
vae = "madebyollin/sdxl-vae-fp16-fix"
prompt = "masterpiece, best quality, newest, 1girl, solo, depth of field, rim lighting, flowers, petals, crystals, butterfly, scenery, upper body, dark red hair, straight hair, long hair, blue eyes, cat ears, mature female, white sweater, blush, slight smile,"
negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"
seed = 123456789
num_inference_steps = 20

pipeline = AutoPipelineForText2Image.from_pretrained(model, vae=AutoencoderKL.from_pretrained(vae))
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
pipeline.safety_checker = None
pipeline = pipeline.to(torch.bfloat16)
pipeline.enable_model_cpu_offload(device="xpu")

# pipelines take a torch.Generator rather than a bare seed
generator = torch.Generator().manual_seed(seed)

# Will work
image = pipeline(prompt, negative_prompt=negative_prompt, width=1080, height=1080, generator=generator, num_inference_steps=num_inference_steps).images[0]
image.save("image.jpg")

# Will fail: after the first call, maybe_free_model_hooks() re-enabled
# offloading on the default "cuda" device instead of "xpu"
image = pipeline(prompt, negative_prompt=negative_prompt, width=1080, height=1080, generator=generator, num_inference_steps=num_inference_steps).images[0]
image.save("image2.jpg")
```

The code above will fail with these logs:
Works as expected with this PR:
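As a quick sanity check (hypothetical snippet, assuming the `_offload_device` attribute introduced by this PR), the recorded device can be read back between calls:

```python
# With the fix, the device recorded at enable time survives the hook reset
# performed after each generation, so the second call also runs on "xpu".
print(pipeline._offload_device)  # expected: "xpu"
```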
thanks!
What does this PR do?
Fixes #: non-CUDA devices with `enable_model_cpu_offload`. Targets are DirectML and IPEX/XPU devices.
When offloading is re-enabled inside `maybe_free_model_hooks`, the device is not passed along, so `enable_model_cpu_offload` falls back to its default of "cuda".
This PR adds `self._offload_device` and passes it to `enable_model_cpu_offload` in `maybe_free_model_hooks`. `self._offload_device` is also set in `enable_sequential_cpu_offload` for compatibility reasons.
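For illustration, a minimal sketch of the recording side, assuming both offload entry points simply store the requested device before their usual hook setup (class name and signatures are simplified placeholders; the real methods live on the pipeline class and also install the accelerate hooks):

```python
class PipelineOffloadSketch:
    def enable_model_cpu_offload(self, gpu_id=None, device="cuda"):
        # Remember the requested accelerator so maybe_free_model_hooks()
        # can re-enable offloading on the same device later.
        self._offload_device = device
        ...  # move submodules to CPU and register the offload hooks

    def enable_sequential_cpu_offload(self, gpu_id=None, device="cuda"):
        # Recorded here too for compatibility, so a pipeline that used
        # sequential offload does not silently fall back to "cuda".
        self._offload_device = device
        ...  # register sequential (per-submodule) offloading
```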
Before submitting
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in this PR.
@patrickvonplaten and @sayakpaul