-
Notifications
You must be signed in to change notification settings - Fork 6.4k
Closed
Labels
bugSomething isn't workingSomething isn't working
Description
Describe the bug
Compiled img2imgControlNet with dynamo cannot be loaded from pretrained folder after compile/save
maybe also bug: if you compile unet/controlnet together or in reverse order -it doesn't work too.
Reproduction
from diffusers import StableDiffusionControlNetPipeline, StableDiffusionPipeline, ControlNetModel, UniPCMultistepScheduler, StableDiffusionControlNetImg2ImgPipeline, EulerAncestralDiscreteScheduler
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation, pipeline
import torch
import numpy as np
from controlnet_aux import OpenposeDetector
from diffusers.utils import load_image
from diffusers.utils import load_image
from PIL import Image
import cv2
import numpy as np
from diffusers.utils import load_image
from controlnet_aux import MLSDdetector
from tqdm import tqdm
from time import time
time_start = time()
image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
depth_estimator = pipeline('depth-estimation')
mlsd = MLSDdetector.from_pretrained('lllyasviel/ControlNet')
image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16).to('cuda')
controlnet_mlsd = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-mlsd", torch_dtype=torch.float16).to('cuda')
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16).to('cuda')
pipe_txt2img = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
extract_ema=True,
torch_dtype=torch.float16).to('cuda')
pipe = StableDiffusionControlNetImg2ImgPipeline( **pipe_txt2img.components, controlnet=[
controlnet_depth,
controlnet_mlsd,
controlnet_canny
]).to('cuda')
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
base_image = load_image(
"http://images.cocodataset.org/val2017/000000039769.jpg"
)
base_image = base_image.resize((512, 512))
#canny
image_cn = np.array(base_image)
low_threshold = 100
high_threshold = 200
image_cn = cv2.Canny(image_cn, low_threshold, high_threshold)
image_cn = image_cn[:, :, None]
image_cn = np.concatenate([image_cn, image_cn, image_cn], axis=2)
control_image_cn = Image.fromarray(image_cn)
#mlsd
control_image_mlsd = mlsd(base_image, detect_resolution=min(base_image.size), image_resolution=min(base_image.size)).resize(base_image.size)
print(control_image_mlsd.size)
#depth
depth_estimator = pipeline('depth-estimation')
image = depth_estimator(base_image)['depth']
image = np.array(image)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
control_image_depth = Image.fromarray(image)
print(control_image_depth.size)
prompt = "some artwork"
negative_prompt = "signature, soft, blurry, drawing, sketch, poor quality, ugly, text, type, word, logo, pixelated, low resolution, saturated, high contrast, oversharpened"
guidance_scale=10
num_inference_steps=20
images = [
control_image_depth,
control_image_mlsd,
control_image_cn
]
# pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
for i in range(1):
image = pipe(
prompt,
negative_prompt=negative_prompt,
image=base_image,
control_image=images,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
strength=0.5
).images[0]
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
# pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
for i in range(1):
image = pipe(
prompt,
negative_prompt=negative_prompt,
image=base_image,
control_image=images,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
strength=0.5
).images[0]
time_end = time()
print(f"Time: {time_end - time_start} seconds")
print("Saving model")
pipe.save_pretrained("test_model")
del pipe
import gc
import torch
gc.collect()
torch.cuda.empty_cache()
print("Loading model")
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained("test_model").to('cuda')
print("Model loaded")
Logs
...
Saving model
Loading model
Loading pipeline components...: 0%| | 0/8 [00:00<?, ?it/s]
Traceback (most recent call last):
File "test_saveload.py", line 130, in <module>
print("Model loaded")
File "/home/oleksandr/miniconda3/envs/fantastic/lib/python3.8/site-packages/diffusers/pipelines/pipeline_utils.py", line 1249, in from_pretrained
loaded_sub_model = load_sub_model(
File "/home/oleksandr/miniconda3/envs/fantastic/lib/python3.8/site-packages/diffusers/pipelines/pipeline_utils.py", line 414, in load_sub_model
class_obj, class_candidates = get_class_obj_and_candidates(
File "/home/oleksandr/miniconda3/envs/fantastic/lib/python3.8/site-packages/diffusers/pipelines/pipeline_utils.py", line 327, in get_class_obj_and_candidates
class_obj = getattr(library, class_name)
File "/home/oleksandr/miniconda3/envs/fantastic/lib/python3.8/site-packages/torch/__init__.py", line 1833, in __getattr__
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'OptimizedModule'
System Info
diffusers
version: 0.22.0.dev0- Platform: Linux-6.2.0-36-generic-x86_64-with-glibc2.17
- Python version: 3.8.16
- PyTorch version (GPU?): 2.1.0 (True)
- Huggingface_hub version: 0.17.2
- Transformers version: 4.30.2
- Accelerate version: 0.21.0
- xFormers version: 0.0.22.post7
Who can help?
No response
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working