@sayakpaul sayakpaul commented Feb 12, 2024

What does this PR do?

This PR adds low_cpu_mem_usage support to load_ip_adapter() to reduce loading time.

Script to test:

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch
import time
import argparse

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")

def run_sd(low_cpu_mem_usage):
    # The timed region covers pipeline creation, the move to GPU, and IP-Adapter loading.
    start = time.time()
    pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    pipeline.load_ip_adapter(
        "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin", low_cpu_mem_usage=low_cpu_mem_usage
    )
    end = time.time()
    print(f"Loading time -- {(end - start):.3f} seconds")

    _ = pipeline(
        prompt="best quality, high quality", 
        ip_adapter_image=image,
        negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
        num_inference_steps=2,
    )


def run_sdxl(low_cpu_mem_usage):
    # The timed region covers pipeline creation, the move to GPU, and IP-Adapter loading.
    start = time.time()
    pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
    pipeline.load_ip_adapter(
        "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin", low_cpu_mem_usage=low_cpu_mem_usage
    )
    end = time.time()
    print(f"Loading time -- {(end - start):.3f} seconds")
    
    _ = pipeline(
        prompt="best quality, high quality", 
        ip_adapter_image=image,
        negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
        num_inference_steps=2,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_sd", action="store_true")
    parser.add_argument("--run_sdxl", action="store_true")
    parser.add_argument("--low_cpu_mem_usage", action="store_true")
    args = parser.parse_args()

    # Exactly one of the two pipelines must be selected per invocation.
    if args.run_sd and args.run_sdxl:
        raise ValueError("Pass only one of `--run_sd` or `--run_sdxl`, not both.")

    if not args.run_sd and not args.run_sdxl:
        raise ValueError("Pass one of `--run_sd` or `--run_sdxl`.")

    fn_to_run = run_sd if args.run_sd else run_sdxl
    fn_to_run(low_cpu_mem_usage=args.low_cpu_mem_usage)

On average, passing low_cpu_mem_usage=True in load_ip_adapter() saves about 2-3 seconds.
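
Since the figure above is an average, it may help to repeat the measurement several times per setting. The following is only a sketch of one way to do that, not part of the PR: `time_calls` and `summarize` are hypothetical helper names, and the callable passed in is assumed to wrap whatever loading step you want to time.

```python
import statistics
import time
from typing import Callable, List


def time_calls(fn: Callable[[], object], repeats: int = 3) -> List[float]:
    """Call `fn` `repeats` times and return each wall-clock duration in seconds."""
    if repeats < 1:
        raise ValueError("`repeats` must be at least 1.")
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def summarize(durations: List[float]) -> str:
    """Format the mean, minimum, and maximum of the measured durations."""
    mean = statistics.mean(durations)
    return (
        f"mean {mean:.3f}s over {len(durations)} runs "
        f"(min {min(durations):.3f}s, max {max(durations):.3f}s)"
    )
```

With the script above, this could be used as `summarize(time_calls(lambda: run_sd(low_cpu_mem_usage=True)))`. Note that `run_sd` also runs a short inference pass, so the durations measured this way would include it. The first run may also include download time if the weights are not cached yet.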

Will add documentation once the PR is approved.

TODO

  • Documentation

@sayakpaul sayakpaul requested a review from yiyixuxu February 12, 2024 10:22

@yiyixuxu yiyixuxu left a comment

thanks! this is great!!!
left a comment, I think we only need to warn once when we actually apply low_cpu_mem_usage
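
The warning under discussion is not shown in this thread, so the following is only a generic sketch of the "warn once" idea, not the diffusers implementation. It assumes a hypothetical loader that handles several weight files and falls back to regular loading when a required dependency is missing. The names `resolve_low_cpu_mem_usage`, `load_all`, and `accelerate_available` are made up for illustration.

```python
import logging

logger = logging.getLogger("warn_once_sketch")


def resolve_low_cpu_mem_usage(requested: bool, accelerate_available: bool) -> bool:
    """Decide whether the low-memory path can be used, warning if it cannot."""
    if requested and not accelerate_available:
        logger.warning(
            "low_cpu_mem_usage=True was requested but cannot be applied; "
            "falling back to regular loading."
        )
        return False
    return requested


def load_all(weight_names, requested: bool, accelerate_available: bool):
    """Hypothetical loader: returns (weight name, flag actually used) pairs."""
    # Resolve the flag once, before the loop, so the warning is emitted
    # at most once rather than once per weight file.
    use_low_mem = resolve_low_cpu_mem_usage(requested, accelerate_available)
    return [(name, use_low_mem) for name in weight_names]
```

Resolving the flag at the single place where it is applied keeps the log free of repeated messages when several adapters are loaded in one call.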

@sayakpaul

Documentation has been dealt with as well.

@sayakpaul sayakpaul merged commit e6d1728 into main Feb 15, 2024
@sayakpaul sayakpaul deleted the ip-adapter-low-cpu-mem-usage branch February 15, 2024 10:07