@@ -683,16 +683,19 @@ def check_inputs(
         self,
         prompt,
         image,
+        mask_image,
         height,
         width,
         callback_steps,
+        output_type,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -736,6 +739,19 @@ def check_inputs(
                     f" {negative_prompt_embeds.shape}."
                 )

+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
         # `prompt` needs more sophisticated handling when there are multiple
         # conditionings.
         if isinstance(self.controlnet, MultiControlNetModel):
@@ -862,7 +878,6 @@ def check_image(self, image, prompt, prompt_embeds):
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )

-    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_control_image(
         self,
         image,
@@ -872,10 +887,14 @@ def prepare_control_image(
         num_images_per_prompt,
         device,
         dtype,
+        crops_coords,
+        resize_mode,
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image = self.control_image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        ).to(dtype=torch.float32)
         image_batch_size = image.shape[0]

         if image_batch_size == 1:
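`prepare_control_image` now forwards `crops_coords` and `resize_mode` to the image processor, so the ControlNet conditioning image is cropped to the same region as the image being inpainted. A loose, hypothetical sketch of the crop-then-resize idea (the actual behavior lives in diffusers' `VaeImageProcessor`, whose `resize_mode="fill"` handling may differ in detail):

```python
# Hypothetical helper illustrating crop-aware preprocessing; not the
# real VaeImageProcessor implementation.
import PIL.Image


def crop_and_resize_sketch(image: PIL.Image.Image, height: int, width: int,
                           crops_coords=None) -> PIL.Image.Image:
    if crops_coords is not None:
        image = image.crop(crops_coords)  # (x1, y1, x2, y2)
    # Scale the (possibly cropped) image back up to the full target size,
    # so the pipeline always operates at the requested resolution.
    return image.resize((width, height), PIL.Image.LANCZOS)
```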
@@ -1074,6 +1093,7 @@ def __call__(
         control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 1.0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
@@ -1130,6 +1150,12 @@ def __call__(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The width in pixels of the generated image.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to the
+                image and mask_image. Otherwise, it first finds a rectangular region with the same aspect ratio as the image
+                that contains all of the masked area, expands that region by `padding_mask_crop`, and crops the image and
+                mask_image to it before resizing to the original image size for inpainting. This is useful when the masked
+                area is small while the image is large and contains information irrelevant to inpainting, such as background.
             strength (`float`, *optional*, defaults to 1.0):
                 Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                 starting point and more noise is added the higher the `strength`. The number of denoising steps depends
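To make the new argument concrete, here is a minimal usage sketch, assuming this diff lands in `StableDiffusionControlNetInpaintPipeline`; the local file names are placeholders:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

image = load_image("image.png")            # must be a PIL image when cropping
mask_image = load_image("mask.png")        # white marks the inpaint region
control_image = load_image("control.png")  # ControlNet conditioning image

# padding_mask_crop=32: crop to the masked region plus a 32-pixel margin,
# inpaint the crop at full resolution, then paste it back afterwards.
result = pipe(
    "a red sports car",
    image=image,
    mask_image=mask_image,
    control_image=control_image,
    padding_mask_crop=32,
    output_type="pil",  # must remain "pil" when padding_mask_crop is set
).images[0]
```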
@@ -1240,16 +1266,19 @@ def __call__(
         self.check_inputs(
             prompt,
             control_image,
+            mask_image,
             height,
             width,
             callback_steps,
+            output_type,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )

         self._guidance_scale = guidance_scale
@@ -1264,6 +1293,14 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]

+        if padding_mask_crop is not None:
+            height, width = self.image_processor.get_default_height_width(image, height, width)
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
         device = self._execution_device

         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
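For intuition, here is a simplified sketch of the kind of logic `get_crop_region` implements; the real `VaeImageProcessor` method may differ in details such as border handling:

```python
import numpy as np
import PIL.Image


def get_crop_region_sketch(mask_image: PIL.Image.Image, width: int,
                           height: int, pad: int = 0):
    """Bounding box of the mask, grown by `pad` and expanded toward the
    image's aspect ratio, clamped to the image bounds."""
    mask = np.array(mask_image.convert("L").resize((width, height)))
    ys, xs = np.nonzero(mask)
    if xs.size == 0:  # empty mask: fall back to the full image
        return 0, 0, width, height
    x1 = max(int(xs.min()) - pad, 0)
    x2 = min(int(xs.max()) + pad + 1, width)
    y1 = max(int(ys.min()) - pad, 0)
    y2 = min(int(ys.max()) + pad + 1, height)
    if (x2 - x1) / (y2 - y1) < width / height:
        # Crop is too narrow: widen it to match the image's aspect ratio.
        target_w = int((y2 - y1) * width / height)
        x1 = max(x1 - (target_w - (x2 - x1)) // 2, 0)
        x2 = min(x1 + target_w, width)
    else:
        # Crop is too wide: make it taller instead.
        target_h = int((x2 - x1) * height / width)
        y1 = max(y1 - (target_h - (y2 - y1)) // 2, 0)
        y2 = min(y1 + target_h, height)
    return x1, y1, x2, y2
```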
@@ -1315,6 +1352,8 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
@@ -1330,6 +1369,8 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
+                    crops_coords=crops_coords,
+                    resize_mode=resize_mode,
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
@@ -1341,10 +1382,15 @@ def __call__(
             assert False

         # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)

-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )

         masked_image = init_image * (mask < 0.5)
         _, _, height, width = init_image.shape
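Note the `mask < 0.5` comparison: the preprocessed mask is effectively binarized, and pixels at or above 0.5 (the region to inpaint) are zeroed out of the masked image. A toy illustration of the semantics:

```python
import torch

init_image = torch.rand(1, 3, 512, 512) * 2 - 1    # preprocessed to [-1, 1]
mask = (torch.rand(1, 1, 512, 512) > 0.5).float()  # 1.0 = inpaint here
masked_image = init_image * (mask < 0.5)           # keep only unmasked pixels
```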
@@ -1534,6 +1580,9 @@ def __call__(

         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()
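Finally, `apply_overlay` pastes the inpainted crop back into the untouched original, so only the masked region changes. A rough sketch of that step; the real `VaeImageProcessor.apply_overlay` may blend differently:

```python
import PIL.Image


def apply_overlay_sketch(mask: PIL.Image.Image, init_image: PIL.Image.Image,
                         inpainted: PIL.Image.Image, crop_coords=None):
    """Paste `inpainted` into `init_image`, using `mask` as the paste alpha."""
    width, height = init_image.width, init_image.height
    if crop_coords is not None:
        # Resize the inpainted result to the crop's size and position it
        # on a full-size canvas at the crop's location.
        x1, y1, x2, y2 = crop_coords
        canvas = PIL.Image.new("RGB", (width, height))
        canvas.paste(inpainted.resize((x2 - x1, y2 - y1)), (x1, y1))
        inpainted = canvas
    else:
        inpainted = inpainted.resize((width, height))
    result = init_image.copy()
    result.paste(inpainted, (0, 0), mask.convert("L").resize((width, height)))
    return result
```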