Commit 556fa26

Add linear interpolation between two prompts
1 parent 4ec2873 commit 556fa26

File tree

1 file changed (+124, -0 lines)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 124 additions & 0 deletions
@@ -1,5 +1,6 @@
 import inspect
 import warnings
+import random
 from typing import List, Optional, Union
 
 import torch
@@ -163,3 +164,126 @@ def __call__(
             image = self.numpy_to_pil(image)
 
         return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+
+    def get_text_latent_space(self, prompt):
+        # get prompt text embeddings
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        return text_embeddings
+
+    def lerp_between_prompts(self, first_prompt, second_prompt, seed=None, length=10, save=False, **kwargs):
+        first_embedding = self.get_text_latent_space(first_prompt)
+        second_embedding = self.get_text_latent_space(second_prompt)
+        if seed is None:
+            seed = random.randint(0, 2**32 - 1)
+        generator = torch.Generator("cuda")
+        lerp_embed_points = []
+        for i in range(length):
+            weight = i / length
+            # linearly interpolate between the two prompt embeddings
+            tensor_lerp = torch.lerp(first_embedding, second_embedding, weight)
+            # append (not extend) so each point keeps its batch dimension
+            lerp_embed_points.append(tensor_lerp)
+        images = []
+        for idx, latent_point in enumerate(lerp_embed_points):
+            # reset the seed for every point so only the prompt embedding changes
+            generator.manual_seed(seed)
+            image = self.image_from_latent_space(latent_point, generator=generator, **kwargs)
+            images.extend(image)
+            if save:
+                image[0].save(f"{first_prompt}-{second_prompt}-{idx:02d}.png")
+        return images
+
+    def image_from_latent_space(
+        self,
+        text_embeddings,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        **kwargs,
+    ):
+        batch_size = text_embeddings.shape[0]
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier-free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get unconditional embeddings for classifier-free guidance
+        if do_classifier_free_guidance:
+            max_length = text_embeddings.shape[1]
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+            # For classifier-free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes.
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        # get the initial random noise
+        latents = torch.randn(
+            (batch_size, self.unet.in_channels, height // 8, width // 8),
+            generator=generator,
+            device=self.device,
+        )
+
+        # set timesteps
+        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+        extra_set_kwargs = {}
+        if accepts_offset:
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            latents = latents * self.scheduler.sigmas[0]
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be in [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+            # expand the latents if we are doing classifier-free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[i]
+                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return image
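
The new entry point is lerp_between_prompts: it embeds both prompts, walks torch.lerp between the two embeddings, and decodes every interpolation point from the same initial noise so that only the text conditioning changes from frame to frame. A minimal usage sketch (not part of the commit; the checkpoint name, prompts, and use_auth_token flag are illustrative assumptions, and a CUDA device is assumed):

    import torch
    from diffusers import StableDiffusionPipeline

    # load the pipeline built from this branch (checkpoint name is an assumption)
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", use_auth_token=True
    ).to("cuda")

    # ten frames morphing from the first prompt toward the second,
    # re-seeded identically for every frame so the composition stays stable
    frames = pipe.lerp_between_prompts(
        "a photo of a cat",
        "a photo of a dog",
        seed=42,
        length=10,
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    frames[0].save("lerp-00.png")

Note that with weight = i / length the interpolation stops one step short of 1.0, so the final frame does not land exactly on the second prompt's embedding.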
