update launch script
vicgalle committed Oct 8, 2022
commit f9a4c7ccc00f19fe11ebfc86a2c80fb58830f10e
123 changes: 86 additions & 37 deletions scripts/txt2img.py
@@ -18,7 +18,9 @@
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler

-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
 from transformers import AutoFeatureExtractor


@@ -68,7 +70,7 @@ def load_model_from_config(config, ckpt, verbose=False):
 def put_watermark(img, wm_encoder=None):
     if wm_encoder is not None:
         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-        img = wm_encoder.encode(img, 'dwtDct')
+        img = wm_encoder.encode(img, "dwtDct")
         img = Image.fromarray(img[:, :, ::-1])
     return img
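Note for reviewers: put_watermark converts the PIL image from RGB to OpenCV's BGR layout before handing it to the invisible-watermark encoder, then converts back. A minimal round-trip sketch of that library's API, runnable on its own (assumes the invisible-watermark package is installed; the random image is a stand-in):

    # Round-trip sketch of the imwatermark API wrapped by put_watermark above.
    import numpy as np
    from imwatermark import WatermarkDecoder, WatermarkEncoder

    payload = "StableDiffusionV1".encode("utf-8")
    bgr = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # stand-in image

    encoder = WatermarkEncoder()
    encoder.set_watermark("bytes", payload)
    marked = encoder.encode(bgr, "dwtDct")  # embed in the DWT/DCT domain

    decoder = WatermarkDecoder("bytes", 8 * len(payload))  # length in bits
    print(decoder.decode(marked, "dwtDct"))  # should print b'StableDiffusionV1'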

@@ -77,16 +79,20 @@ def load_replacement(x):
     try:
         hwc = x.shape
         y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
-        y = (np.array(y)/255.0).astype(x.dtype)
+        y = (np.array(y) / 255.0).astype(x.dtype)
         assert y.shape == x.shape
         return y
     except Exception:
         return x


 def check_safety(x_image):
-    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
-    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+    safety_checker_input = safety_feature_extractor(
+        numpy_to_pil(x_image), return_tensors="pt"
+    )
+    x_checked_image, has_nsfw_concept = safety_checker(
+        images=x_image, clip_input=safety_checker_input.pixel_values
+    )
     assert x_checked_image.shape[0] == len(has_nsfw_concept)
     for i in range(len(has_nsfw_concept)):
         if has_nsfw_concept[i]:
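For context (not part of this diff): check_safety reads two module-level globals that txt2img.py creates from the CompVis safety-checker weights, roughly as below. The snippet restates the imports so it stands alone:

    # How the globals used by check_safety are initialized elsewhere in the file.
    from diffusers.pipelines.stable_diffusion.safety_checker import (
        StableDiffusionSafetyChecker,
    )
    from transformers import AutoFeatureExtractor

    safety_model_id = "CompVis/stable-diffusion-safety-checker"
    safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)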
@@ -102,23 +108,23 @@ def main():
         type=str,
         nargs="?",
         default="a painting of a virus monster playing guitar",
-        help="the prompt to render"
+        help="the prompt to render",
     )
     parser.add_argument(
         "--outdir",
         type=str,
         nargs="?",
         help="dir to write results to",
-        default="outputs/txt2img-samples"
+        default="outputs/txt2img-samples",
     )
     parser.add_argument(
         "--skip_grid",
-        action='store_true',
+        action="store_true",
         help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
     )
     parser.add_argument(
         "--skip_save",
-        action='store_true',
+        action="store_true",
         help="do not save individual samples. For speed measurements.",
     )
     parser.add_argument(
@@ -129,17 +135,17 @@
     )
     parser.add_argument(
         "--plms",
-        action='store_true',
+        action="store_true",
         help="use plms sampling",
     )
     parser.add_argument(
         "--laion400m",
-        action='store_true',
+        action="store_true",
         help="uses the LAION400M model",
     )
     parser.add_argument(
         "--fixed_code",
-        action='store_true',
+        action="store_true",
         help="if enabled, uses the same starting code across samples ",
     )
     parser.add_argument(
@@ -204,7 +210,7 @@ def main():
     parser.add_argument(
         "--config",
         type=str,
-        default="configs/stable-diffusion/v1-inference.yaml",
+        default="configs/stable-diffusion/v1-inference-aesthetic.yaml",
         help="path to config which constructs model",
     )
     parser.add_argument(
@@ -224,7 +230,25 @@
         type=str,
         help="evaluate at this precision",
         choices=["full", "autocast"],
-        default="autocast"
+        default="autocast",
     )
+    parser.add_argument(
+        "--aesthetic_steps",
+        type=int,
+        help="number of steps for the aesthetic personalization",
+        default=5,
+    )
+    parser.add_argument(
+        "--aesthetic_lr",
+        type=float,
+        help="learning rate for the aesthetic personalization",
+        default=0.0001,
+    )
+    parser.add_argument(
+        "--aesthetic_embedding",
+        type=str,
+        help="aesthetic embedding file",
+        default="aesthetic_embeddings/sac_8plus.pt",
+    )
     opt = parser.parse_args()
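Reviewer note on the three new flags: --aesthetic_steps sets how many personalization steps T the conditioning stage runs, --aesthetic_lr sets the learning rate for those steps (a float, so fractional command-line values parse), and --aesthetic_embedding points at a precomputed aesthetic embedding such as the bundled sac_8plus.pt. A typical invocation would look like: python scripts/txt2img.py --prompt "a fantasy landscape" --plms --aesthetic_steps 10 --aesthetic_lr 0.0001 --aesthetic_embedding aesthetic_embeddings/sac_8plus.pt (the prompt here is only an example).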

@@ -237,6 +261,14 @@
     seed_everything(opt.seed)

     config = OmegaConf.load(f"{opt.config}")
+
+    # Override config with personalization arguments
+    config.model.params.cond_stage_config.params.T = opt.aesthetic_steps
+    config.model.params.cond_stage_config.params.lr = opt.aesthetic_lr
+    config.model.params.cond_stage_config.params.aesthetic_embedding_path = (
+        opt.aesthetic_embedding
+    )
+
     model = load_model_from_config(config, f"{opt.ckpt}")

     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
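A note on the override block above: OmegaConf exposes the loaded YAML as attribute-style nested nodes, so these dotted assignments rewrite the config tree in place before load_model_from_config builds the model. This assumes v1-inference-aesthetic.yaml declares T, lr, and aesthetic_embedding_path under cond_stage_config.params. A self-contained sketch of the mechanism (the tiny tree here is made up for illustration):

    # Minimal demonstration of in-place OmegaConf overrides.
    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {"model": {"params": {"cond_stage_config": {"params": {"T": 0, "lr": 0.0}}}}}
    )
    config.model.params.cond_stage_config.params.T = 10     # opt.aesthetic_steps
    config.model.params.cond_stage_config.params.lr = 1e-4  # opt.aesthetic_lr
    print(OmegaConf.to_yaml(config))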
@@ -250,10 +282,12 @@
     os.makedirs(opt.outdir, exist_ok=True)
     outpath = opt.outdir

-    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+    print(
+        "Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)..."
+    )
     wm = "StableDiffusionV1"
     wm_encoder = WatermarkEncoder()
-    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+    wm_encoder.set_watermark("bytes", wm.encode("utf-8"))

     batch_size = opt.n_samples
     n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
@@ -275,9 +309,11 @@

     start_code = None
     if opt.fixed_code:
-        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+        start_code = torch.randn(
+            [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device
+        )

-    precision_scope = autocast if opt.precision=="autocast" else nullcontext
+    precision_scope = autocast if opt.precision == "autocast" else nullcontext
     with torch.no_grad():
         with precision_scope("cuda"):
             with model.ema_scope():
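Context on the precision toggle a few lines up: with --precision autocast the forward passes run under torch's mixed-precision context, while --precision full substitutes a do-nothing context manager. A standalone sketch of the pattern (the precision value is a placeholder; without a CUDA device, autocast only warns and disables itself):

    # The autocast-or-nullcontext pattern used by txt2img.py.
    from contextlib import nullcontext

    from torch import autocast

    precision = "autocast"  # stand-in for opt.precision
    precision_scope = autocast if precision == "autocast" else nullcontext
    with precision_scope("cuda"):
        pass  # the sampling loop would run here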
@@ -292,30 +328,42 @@ def main():
                             prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                         conditioning=c,
-                                                         batch_size=opt.n_samples,
-                                                         shape=shape,
-                                                         verbose=False,
-                                                         unconditional_guidance_scale=opt.scale,
-                                                         unconditional_conditioning=uc,
-                                                         eta=opt.ddim_eta,
-                                                         x_T=start_code)
+                        samples_ddim, _ = sampler.sample(
+                            S=opt.ddim_steps,
+                            conditioning=c,
+                            batch_size=opt.n_samples,
+                            shape=shape,
+                            verbose=False,
+                            unconditional_guidance_scale=opt.scale,
+                            unconditional_conditioning=uc,
+                            eta=opt.ddim_eta,
+                            x_T=start_code,
+                        )

                         x_samples_ddim = model.decode_first_stage(samples_ddim)
-                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                        x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+                        x_samples_ddim = torch.clamp(
+                            (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0
+                        )
+                        x_samples_ddim = (
+                            x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+                        )

                         x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)

-                        x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+                        x_checked_image_torch = torch.from_numpy(
+                            x_checked_image
+                        ).permute(0, 3, 1, 2)

                         if not opt.skip_save:
                             for x_sample in x_checked_image_torch:
-                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                x_sample = 255.0 * rearrange(
+                                    x_sample.cpu().numpy(), "c h w -> h w c"
+                                )
                                 img = Image.fromarray(x_sample.astype(np.uint8))
                                 img = put_watermark(img, wm_encoder)
-                                img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+                                img.save(
+                                    os.path.join(sample_path, f"{base_count:05}.png")
+                                )
                                 base_count += 1

                         if not opt.skip_grid:
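Context for the unconditional_guidance_scale / unconditional_conditioning pair handed to sampler.sample: inside the DDIM and PLMS samplers these drive classifier-free guidance, where the model is evaluated on both the prompt and the empty prompt and the two noise predictions are blended. A standalone sketch of the blend (paraphrasing the sampler code; the random tensors are stand-ins for real noise predictions):

    # Classifier-free guidance combination, as applied per denoising step.
    import torch

    scale = 7.5  # stand-in for opt.scale
    e_t_uncond = torch.randn(1, 4, 64, 64)  # prediction for the empty prompt
    e_t_cond = torch.randn(1, 4, 64, 64)    # prediction for the actual prompt
    e_t = e_t_uncond + scale * (e_t_cond - e_t_uncond)
    print(e_t.shape)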
@@ -324,20 +372,21 @@
                 if not opt.skip_grid:
                     # additionally, save as grid
                     grid = torch.stack(all_samples, 0)
-                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                    grid = rearrange(grid, "n b c h w -> (n b) c h w")
                     grid = make_grid(grid, nrow=n_rows)

                     # to image
-                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                    grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                     img = Image.fromarray(grid.astype(np.uint8))
                     img = put_watermark(img, wm_encoder)
-                    img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                    img.save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                     grid_count += 1

                 toc = time.time()

-    print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
-          f" \nEnjoy.")
+    print(
+        f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy."
+    )


 if __name__ == "__main__":