diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md
index b37ef15fc6af..7b5b4fd3111f 100644
--- a/docs/source/en/using-diffusers/ip_adapter.md
+++ b/docs/source/en/using-diffusers/ip_adapter.md
@@ -231,6 +231,9 @@
 export_to_gif(frames, "gummy_bear.gif")
 
+> [!TIP]
+> When calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up loading.
+
 ## Specific use cases
 
 IP-Adapter's image prompting and compatibility with other adapters and models make it a versatile tool for a variety of use cases. This section covers some of the more popular applications of IP-Adapter, and we can't wait to see what you come up with!
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index bf3c419261c4..d6bf50e51eee 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -19,8 +19,11 @@
 from huggingface_hub.utils import validate_hf_hub_args
 from safetensors import safe_open
 
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import (
     _get_model_file,
+    is_accelerate_available,
+    is_torch_version,
     is_transformers_available,
     logging,
 )
@@ -86,6 +89,11 @@
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
         """
 
         # handle the list inputs for multiple IP Adapters
@@ -116,6 +124,22 @@
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
 
         user_agent = {
             "file_type": "attn_procs_weights",
@@ -165,6 +189,7 @@
                 image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                     pretrained_model_name_or_path_or_dict,
                     subfolder=Path(subfolder, "image_encoder").as_posix(),
+                    low_cpu_mem_usage=low_cpu_mem_usage,
                 ).to(self.device, dtype=self.dtype)
                 self.register_modules(image_encoder=image_encoder)
             else:
@@ -175,9 +200,9 @@
             feature_extractor = CLIPImageProcessor()
             self.register_modules(feature_extractor=feature_extractor)
 
-        # load ip-adapter into unet
+        # load ip-adapter into unet
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet._load_ip_adapter_weights(state_dicts)
+        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
 
     def set_ip_adapter_scale(self, scale):
         """
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 4d3d672d38ae..9d8e2666c518 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -37,6 +37,7 @@
     _get_model_file,
     delete_adapter_layers,
     is_accelerate_available,
+    is_torch_version,
    logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
@@ -168,15 +169,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             "framework": "pytorch",
         }
 
-        if low_cpu_mem_usage and not is_accelerate_available():
-            low_cpu_mem_usage = False
-            logger.warning(
-                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
-                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
-                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
-                " install accelerate\n```\n."
-            )
-
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
@@ -694,9 +686,29 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)
 
-    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         updated_state_dict = {}
         image_projection = None
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
 
         if "proj.weight" in state_dict:
             # IP-Adapter
@@ -704,11 +716,12 @@
             clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
             cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
 
-            image_projection = ImageProjection(
-                cross_attention_dim=cross_attention_dim,
-                image_embed_dim=clip_embeddings_dim,
-                num_image_text_embeds=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = ImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=clip_embeddings_dim,
+                    num_image_text_embeds=num_image_text_embeds,
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("proj", "image_embeds")
@@ -719,9 +732,10 @@
             clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
             cross_attention_dim = state_dict["proj.3.weight"].shape[0]
 
-            image_projection = IPAdapterFullImageProjection(
-                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
-            )
+            with init_context():
+                image_projection = IPAdapterFullImageProjection(
+                    cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("proj.0", "ff.net.0.proj")
@@ -737,13 +751,14 @@
             hidden_dims = state_dict["latents"].shape[2]
             heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
 
-            image_projection = IPAdapterPlusImageProjection(
-                embed_dims=embed_dims,
-                output_dims=output_dims,
-                hidden_dims=hidden_dims,
-                heads=heads,
-                num_queries=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = IPAdapterPlusImageProjection(
+                    embed_dims=embed_dims,
+                    output_dims=output_dims,
+                    hidden_dims=hidden_dims,
+                    heads=heads,
+                    num_queries=num_image_text_embeds,
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("0.to", "2.to")
@@ -765,10 +780,14 @@
             else:
                 updated_state_dict[diffusers_name] = value
 
-        image_projection.load_state_dict(updated_state_dict)
+        if not low_cpu_mem_usage:
+            image_projection.load_state_dict(updated_state_dict)
+        else:
+            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+
         return image_projection
 
-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
         from ..models.attention_processor import (
             AttnProcessor,
             AttnProcessor2_0,
@@ -776,9 +795,29 @@
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
         )
 
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
         key_id = 1
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
         for name in self.attn_processors.keys():
             cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
             if name.startswith("mid_block"):
@@ -811,39 +850,49 @@
                     # IP-Adapter Plus
                     num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
 
-                attn_procs[name] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                    num_tokens=num_image_text_embeds,
-                ).to(dtype=self.dtype, device=self.device)
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                    )
 
                 value_dict = {}
                 for i, state_dict in enumerate(state_dicts):
                     value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
                     value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
 
-                attn_procs[name].load_state_dict(value_dict)
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
                 key_id += 2
 
         return attn_procs
 
-    def _load_ip_adapter_weights(self, state_dicts):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
         # Set encoder_hid_proj after loading ip_adapter weights,
         # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None
 
-        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
         self.set_attn_processor(attn_procs)
 
         # convert IP-Adapter Image Projection layers to diffusers
         image_projection_layers = []
         for state_dict in state_dicts:
-            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
-            image_projection_layer.to(device=self.device, dtype=self.dtype)
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+            )
             image_projection_layers.append(image_projection_layer)
 
         self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
         self.config.encoder_hid_dim_type = "ip_image_proj"
+
+        self.to(dtype=self.dtype, device=self.device)
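
For reference, a minimal usage sketch of the flag this patch adds. The pipeline class and checkpoint names below are the standard examples from the diffusers IP-Adapter docs, not part of this patch, and `low_cpu_mem_usage=True` assumes `accelerate` is installed and PyTorch >= 1.9.0:

```python
import torch
from diffusers import AutoPipelineForText2Image

# Any IP-Adapter-compatible pipeline works; SD 1.5 is used here for illustration.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# With low_cpu_mem_usage=True, the attention processors and image-projection
# modules are created on the meta device via accelerate's init_empty_weights,
# and the pretrained weights are loaded into them directly, skipping random
# initialization and the extra CPU copy.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    low_cpu_mem_usage=True,
)
```

Note the fallback behavior: when `accelerate` is missing, the loaders warn and silently revert to `low_cpu_mem_usage=False` rather than raising, mirroring `from_pretrained`, so on supported PyTorch versions the flag is safe to pass unconditionally.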