Commit e11bbe8

Fix graph breaks
1 parent 256e696 commit e11bbe8

File tree

src/diffusers/models/attention.py
src/diffusers/models/modeling_utils.py
src/diffusers/models/unet_2d_blocks.py
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

4 files changed (+13 −6 lines)

src/diffusers/models/attention.py

Lines changed: 2 additions & 0 deletions
@@ -192,7 +192,9 @@ def forward(self, hidden_states):
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
         return hidden_states
 
+from torch._dynamo import allow_in_graph
 
+@allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
     A basic Transformer block.
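
For context, a minimal sketch of the allow_in_graph pattern applied above (the helper function, shapes, and call site are illustrative, not part of this commit): torch._dynamo.allow_in_graph tells the compiler to keep a call in the captured graph as an opaque node instead of tracing into it, which avoids the graph breaks Dynamo would otherwise insert.

    import torch
    from torch._dynamo import allow_in_graph


    @allow_in_graph
    def fused_bias_gelu(x, bias):
        # Dynamo records this call as-is rather than tracing its body
        return torch.nn.functional.gelu(x + bias)


    def model_fn(x, bias):
        return fused_bias_gelu(x, bias) * 2.0


    compiled = torch.compile(model_fn)
    out = compiled(torch.randn(4, 16), torch.randn(16))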

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
 
 def get_parameter_dtype(parameter: torch.nn.Module):
     try:
+        return tuple(parameter.parameters())[0].dtype
         parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
         return next(parameters_and_buffers).dtype
     except StopIteration:
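
The change above reads the dtype straight from the first parameter. A simplified, hedged sketch of the resulting lookup (the helper name and the empty-parameter guard are illustrative, not diffusers code):

    import itertools
    import torch


    def get_dtype(module: torch.nn.Module) -> torch.dtype:
        params = tuple(module.parameters())
        if params:
            # fast path added by this commit: the first parameter decides the dtype
            return params[0].dtype
        # fall back to buffers for parameter-less modules
        return next(itertools.chain(module.parameters(), module.buffers())).dtype


    print(get_dtype(torch.nn.Linear(4, 4)))  # torch.float32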

src/diffusers/models/unet_2d_blocks.py

Lines changed: 6 additions & 3 deletions
@@ -560,7 +560,8 @@ def forward(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
@@ -868,7 +869,8 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False
+                )[0]
 
             output_states += (hidden_states,)
 
@@ -1859,7 +1861,8 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False
+                )[0]
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
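
Each of the three hunks above swaps an attribute access on a model output for tuple indexing. A hedged sketch of why that matters (the toy module and output class below are hypothetical, not diffusers code): constructing and then unpacking an output dataclass is a common source of graph breaks under torch.compile, while returning and indexing a plain tuple traces cleanly.

    from dataclasses import dataclass

    import torch


    @dataclass
    class SampleOutput:
        sample: torch.Tensor


    class TinyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

        def forward(self, x, return_dict=True):
            out = self.proj(x)
            # default: dataclass output with a .sample field; plain tuple when return_dict=False
            return SampleOutput(sample=out) if return_dict else (out,)


    block = TinyBlock()
    x = torch.randn(2, 8)
    # both call styles yield the same tensor
    assert torch.equal(block(x).sample, block(x, return_dict=False)[0])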

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 4 additions & 3 deletions
@@ -440,7 +440,7 @@ def run_safety_checker(self, image, device, dtype):
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -686,15 +686,16 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False, **extra_step_kwargs)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
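
With the tuple-returning calls above, the denoising loop no longer breaks the graph, so the UNet can be compiled end to end. A hedged usage sketch (the model id, dtype, and step count are illustrative, not from the commit):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    # compile the denoiser; the first pipeline call triggers tracing and compilation
    pipe.unet = torch.compile(pipe.unet)

    image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]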
