
Commit 8e7bbfb

rootonchair, sayakpaul, and yiyixuxu authored
add padding_mask_crop to all inpaint pipelines (huggingface#6360)
* add padding_mask_crop

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
1 parent e2773c6 commit 8e7bbfb
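
A minimal usage sketch of the new argument on the ControlNet inpaint pipeline. The checkpoint IDs are real but interchangeable, the image URLs are placeholders, and a real workflow would build a proper inpaint conditioning image for control_image:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

image = load_image("https://example.com/photo.png")      # placeholder URL; must be a PIL image
mask_image = load_image("https://example.com/mask.png")  # white pixels mark the area to repaint

# With padding_mask_crop set, only a crop around the masked area (padded by 32
# pixels, then expanded to the image's aspect ratio) is inpainted at full
# resolution and pasted back into the original image afterwards.
result = pipe(
    prompt="a red sofa",
    image=image,
    mask_image=mask_image,
    control_image=image,  # placeholder conditioning input for this sketch
    padding_mask_crop=32,
).images[0]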

File tree

4 files changed: +154 -14 lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 53 additions & 4 deletions
@@ -683,16 +683,19 @@ def check_inputs(
         self,
         prompt,
         image,
+        mask_image,
         height,
         width,
         callback_steps,
+        output_type,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

@@ -736,6 +739,19 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )

+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
         # `prompt` needs more sophisticated handling when there are multiple
         # conditionings.
         if isinstance(self.controlnet, MultiControlNetModel):

@@ -862,7 +878,6 @@ def check_image(self, image, prompt, prompt_embeds):
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )

-    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_control_image(
         self,
         image,

@@ -872,10 +887,14 @@ def prepare_control_image(
         num_images_per_prompt,
         device,
         dtype,
+        crops_coords,
+        resize_mode,
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image = self.control_image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        ).to(dtype=torch.float32)
         image_batch_size = image.shape[0]

         if image_batch_size == 1:

@@ -1074,6 +1093,7 @@ def __call__(
         control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 1.0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,

@@ -1130,6 +1150,12 @@ def __call__(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The width in pixels of the generated image.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to the image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains all masked areas, and then expand that region by `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded region before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant to inpainting, such as background.
             strength (`float`, *optional*, defaults to 1.0):
                 Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                 starting point and more noise is added the higher the `strength`. The number of denoising steps depends

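The selection behavior this docstring describes can be approximated in a few lines. A standalone sketch for illustration only, not the library's implementation (the pipeline delegates to the image processor's get_crop_region):

import numpy as np
import PIL.Image

def sketch_crop_region(mask: PIL.Image.Image, width: int, height: int, pad: int = 0):
    # Bounding box of the non-zero mask pixels, padded by `pad` and clamped to the image.
    m = np.array(mask.convert("L")) > 0
    ys, xs = np.nonzero(m)  # assumes the mask is non-empty
    x1, y1 = max(int(xs.min()) - pad, 0), max(int(ys.min()) - pad, 0)
    x2, y2 = min(int(xs.max()) + pad, width), min(int(ys.max()) + pad, height)
    # Expand the shorter dimension so the crop keeps the image's aspect ratio.
    ratio = width / height
    box_w, box_h = x2 - x1, y2 - y1
    if box_w / box_h < ratio:  # crop too narrow: widen it
        grow = int(box_h * ratio) - box_w
        x1, x2 = max(x1 - grow // 2, 0), min(x2 + grow - grow // 2, width)
    else:  # crop too wide: make it taller
        grow = int(box_w / ratio) - box_h
        y1, y2 = max(y1 - grow // 2, 0), min(y2 + grow - grow // 2, height)
    return x1, y1, x2, y2  # the (left, upper, right, lower) order PIL's crop() expects
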
@@ -1240,16 +1266,19 @@ def __call__(
         self.check_inputs(
             prompt,
             control_image,
+            mask_image,
             height,
             width,
             callback_steps,
+            output_type,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )

         self._guidance_scale = guidance_scale

@@ -1264,6 +1293,14 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]

+        if padding_mask_crop is not None:
+            height, width = self.image_processor.get_default_height_width(image, height, width)
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
         device = self._execution_device

         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):

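Since both processor methods used in this hunk are called on public pipeline attributes, the crop can be previewed before running inference. A small sketch, reusing pipe, image, and mask_image from the example at the top:

# Mirror the pipeline's crop computation to inspect what will actually be inpainted.
height, width = pipe.image_processor.get_default_height_width(image, None, None)
crops_coords = pipe.mask_processor.get_crop_region(mask_image, width, height, pad=32)
image.crop(crops_coords).save("crop_preview.png")  # (left, upper, right, lower)
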
@@ -1315,6 +1352,8 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )

@@ -1330,6 +1369,8 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
+                    crops_coords=crops_coords,
+                    resize_mode=resize_mode,
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )

@@ -1341,10 +1382,15 @@ def __call__(
             assert False

         # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)

-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )

         masked_image = init_image * (mask < 0.5)
         _, _, height, width = init_image.shape

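On the masking convention in this hunk: after preprocessing, mask values lie in [0, 1], values of 0.5 and above mark pixels to repaint, and multiplying by `mask < 0.5` zeroes out exactly the region to be regenerated. A tiny self-contained illustration:

import torch

init_image = torch.ones(1, 3, 4, 4)       # stand-in for a preprocessed image batch
mask = torch.zeros(1, 1, 4, 4)
mask[..., 1:3, 1:3] = 1.0                 # region to inpaint
masked_image = init_image * (mask < 0.5)  # inpaint region zeroed, rest kept
assert masked_image[0, :, 2, 2].sum() == 0
assert masked_image[0, :, 0, 0].sum() == 3
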
@@ -1534,6 +1580,9 @@ def __call__(

         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()

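The apply_overlay step is what keeps everything outside the mask pixel-identical: it scales the inpainted output back down to the crop and pastes it into the untouched original through the mask. A rough PIL approximation of the idea, not the actual VaeImageProcessor.apply_overlay implementation:

import PIL.Image

def sketch_apply_overlay(mask, original, inpainted, crop_coords):
    x1, y1, x2, y2 = crop_coords
    result = original.copy()
    # Resize the full-resolution inpainted output back down to the crop's size...
    patch = inpainted.resize((x2 - x1, y2 - y1))
    # ...and paste it in, blended through the cropped mask so only masked pixels change.
    alpha = mask.convert("L").crop(crop_coords).resize((x2 - x1, y2 - y1))
    result.paste(patch, (x1, y1), mask=alpha)
    return result
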
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 53 additions & 3 deletions
@@ -557,9 +557,11 @@ def check_inputs(
         prompt,
         prompt_2,
         image,
+        mask_image,
         strength,
         num_inference_steps,
         callback_steps,
+        output_type,
         negative_prompt=None,
         negative_prompt_2=None,
         prompt_embeds=None,

@@ -570,6 +572,7 @@ def check_inputs(
         control_guidance_start=0.0,
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

@@ -632,6 +635,19 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )

+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."

@@ -745,10 +761,14 @@ def prepare_control_image(
         num_images_per_prompt,
         device,
         dtype,
+        crops_coords,
+        resize_mode,
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image = self.control_image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        ).to(dtype=torch.float32)
         image_batch_size = image.shape[0]

         if image_batch_size == 1:

@@ -1066,6 +1086,7 @@ def __call__(
         ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 0.9999,
         num_inference_steps: int = 50,
         denoising_start: Optional[float] = None,

@@ -1121,6 +1142,12 @@ def __call__(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to the image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains all masked areas, and then expand that region by `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded region before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant to inpainting, such as background.
             strength (`float`, *optional*, defaults to 0.9999):
                 Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                 between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the

@@ -1290,9 +1317,11 @@ def __call__(
             prompt,
             prompt_2,
             control_image,
+            mask_image,
             strength,
             num_inference_steps,
             callback_steps,
+            output_type,
             negative_prompt,
             negative_prompt_2,
             prompt_embeds,

@@ -1303,6 +1332,7 @@ def __call__(
             control_guidance_start,
             control_guidance_end,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )

         self._guidance_scale = guidance_scale

@@ -1370,7 +1400,18 @@ def denoising_value_valid(dnv):

         # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
         # 5.1 Prepare init image
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        if padding_mask_crop is not None:
+            height, width = self.image_processor.get_default_height_width(image, height, width)
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)

         # 5.2 Prepare control images

@@ -1383,6 +1424,8 @@ def denoising_value_valid(dnv):
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )

@@ -1398,6 +1441,8 @@ def denoising_value_valid(dnv):
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
+                    crops_coords=crops_coords,
+                    resize_mode=resize_mode,
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )

@@ -1409,7 +1454,9 @@ def denoising_value_valid(dnv):
             raise ValueError(f"{controlnet.__class__} is not supported.")

         # 5.3 Prepare mask
-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )

         masked_image = init_image * (mask < 0.5)
         _, _, height, width = init_image.shape

@@ -1684,6 +1731,9 @@ def denoising_value_valid(dnv):

         image = self.image_processor.postprocess(image, output_type=output_type)

+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 4 additions & 5 deletions
@@ -642,6 +642,7 @@ def check_inputs(
         width,
         strength,
         callback_steps,
+        output_type,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,

@@ -693,11 +694,6 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )
         if padding_mask_crop is not None:
-            if self.unet.config.in_channels != 4:
-                raise ValueError(
-                    f"The UNet should have 4 input channels for inpainting mask crop, but has"
-                    f" {self.unet.config.in_channels} input channels."
-                )
             if not isinstance(image, PIL.Image.Image):
                 raise ValueError(
                     f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."

@@ -707,6 +703,8 @@ def check_inputs(
                     f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                     f" {type(mask_image)}."
                 )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

     def prepare_latents(
         self,

@@ -1166,6 +1164,7 @@ def __call__(
             width,
             strength,
             callback_steps,
+            output_type,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
