Skip to content

Compiled with dynamo (torch.compile) img2imgControlNet pipeline not able to be loaded after save_pretrained #5617

@Alexadar

Description

@Alexadar

Describe the bug

Compiled img2imgControlNet with dynamo cannot be loaded from pretrained folder after compile/save

maybe also bug: if you compile unet/controlnet together or in reverse order -it doesn't work too.

Reproduction

from diffusers import StableDiffusionControlNetPipeline, StableDiffusionPipeline, ControlNetModel, UniPCMultistepScheduler, StableDiffusionControlNetImg2ImgPipeline, EulerAncestralDiscreteScheduler
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation, pipeline
import torch
import numpy as np
from controlnet_aux import OpenposeDetector
from diffusers.utils import load_image
from diffusers.utils import load_image
from PIL import Image
import cv2
import numpy as np
from diffusers.utils import load_image
from controlnet_aux import MLSDdetector
from tqdm import tqdm
from time import time

time_start = time()
image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
depth_estimator = pipeline('depth-estimation')
mlsd = MLSDdetector.from_pretrained('lllyasviel/ControlNet')

image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16).to('cuda')
controlnet_mlsd = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-mlsd", torch_dtype=torch.float16).to('cuda')
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16).to('cuda')

pipe_txt2img = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    extract_ema=True,
    torch_dtype=torch.float16).to('cuda')

pipe = StableDiffusionControlNetImg2ImgPipeline( **pipe_txt2img.components, controlnet=[
        controlnet_depth,
        controlnet_mlsd,
        controlnet_canny
    ]).to('cuda')
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

base_image = load_image(
    "http://images.cocodataset.org/val2017/000000039769.jpg"
)
base_image = base_image.resize((512, 512))

#canny
image_cn = np.array(base_image)

low_threshold = 100
high_threshold = 200

image_cn = cv2.Canny(image_cn, low_threshold, high_threshold)
image_cn = image_cn[:, :, None]
image_cn = np.concatenate([image_cn, image_cn, image_cn], axis=2)
control_image_cn = Image.fromarray(image_cn)

#mlsd
control_image_mlsd = mlsd(base_image, detect_resolution=min(base_image.size), image_resolution=min(base_image.size)).resize(base_image.size)
print(control_image_mlsd.size)

#depth
depth_estimator = pipeline('depth-estimation')
image = depth_estimator(base_image)['depth']
image = np.array(image)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
control_image_depth = Image.fromarray(image)
print(control_image_depth.size)

prompt = "some artwork"
negative_prompt = "signature, soft, blurry, drawing, sketch, poor quality, ugly, text, type, word, logo, pixelated, low resolution, saturated, high contrast, oversharpened"
guidance_scale=10
num_inference_steps=20

images = [
    control_image_depth,
    control_image_mlsd,
    control_image_cn
]

# pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)


for i in range(1):
    image = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=base_image,
        control_image=images,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=0.5
    ).images[0]

pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
# pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)


for i in range(1):
    image = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=base_image,
        control_image=images,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=0.5
    ).images[0]

time_end = time()
print(f"Time: {time_end - time_start} seconds")

print("Saving model")
pipe.save_pretrained("test_model")
del pipe

import gc
import torch
gc.collect()
torch.cuda.empty_cache()

print("Loading model")

pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained("test_model").to('cuda')

print("Model loaded")

Logs

...
Saving model
Loading model
Loading pipeline components...:   0%|                       | 0/8 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "test_saveload.py", line 130, in <module>
    print("Model loaded")
  File "/home/oleksandr/miniconda3/envs/fantastic/lib/python3.8/site-packages/diffusers/pipelines/pipeline_utils.py", line 1249, in from_pretrained
    loaded_sub_model = load_sub_model(
  File "/home/oleksandr/miniconda3/envs/fantastic/lib/python3.8/site-packages/diffusers/pipelines/pipeline_utils.py", line 414, in load_sub_model
    class_obj, class_candidates = get_class_obj_and_candidates(
  File "/home/oleksandr/miniconda3/envs/fantastic/lib/python3.8/site-packages/diffusers/pipelines/pipeline_utils.py", line 327, in get_class_obj_and_candidates
    class_obj = getattr(library, class_name)
  File "/home/oleksandr/miniconda3/envs/fantastic/lib/python3.8/site-packages/torch/__init__.py", line 1833, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'OptimizedModule'

System Info

  • diffusers version: 0.22.0.dev0
  • Platform: Linux-6.2.0-36-generic-x86_64-with-glibc2.17
  • Python version: 3.8.16
  • PyTorch version (GPU?): 2.1.0 (True)
  • Huggingface_hub version: 0.17.2
  • Transformers version: 4.30.2
  • Accelerate version: 0.21.0
  • xFormers version: 0.0.22.post7

Who can help?

No response

Metadata

Metadata

Assignees

Labels

bugSomething isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions