
Commit e31f38b

[docs] Remove attention slicing (huggingface#4518)
* remove attention slicing
* apply feedback
1 parent 3bd5e07 commit e31f38b

2 files changed: +29 -37 lines changed


docs/source/en/optimization/fp16.md

Lines changed: 2 additions & 35 deletions
@@ -65,42 +65,11 @@ image = pipe(prompt).images[0]
 
 </Tip>
 
-## Sliced attention for additional memory savings
-
-For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
-
-<Tip>
-Attention slicing is useful even if a batch size of just 1 is used - as long
-as the model uses more than one attention head. If there is more than one
-attention head the *QK^T* attention matrix can be computed sequentially for
-each head which can save a significant amount of memory.
-</Tip>
-
-To perform the attention computation sequentially over each head, you only need to invoke [`~DiffusionPipeline.enable_attention_slicing`] in your pipeline before inference, like here:
-
-```Python
-import torch
-from diffusers import DiffusionPipeline
-
-pipe = DiffusionPipeline.from_pretrained(
-    "runwayml/stable-diffusion-v1-5",
-    torch_dtype=torch.float16,
-)
-pipe = pipe.to("cuda")
-
-prompt = "a photo of an astronaut riding a horse on mars"
-pipe.enable_attention_slicing()
-image = pipe(prompt).images[0]
-```
-
-There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
-
-
 ## Sliced VAE decode for larger batches
 
 To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.
 
-You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
+You likely want to couple this with [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
 
 To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:
 
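The example the prose above refers to falls outside this hunk's context. As a rough sketch only (the checkpoint id and the batch of 32 prompts are illustrative assumptions, not part of the diff), enabling sliced VAE decode looks like this:

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the pipeline in half precision so the UNet and VAE take less memory.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Decode the batch latents one image at a time instead of all at once.
pipe.enable_vae_slicing()
images = pipe(["a photo of an astronaut riding a horse on mars"] * 32).images
```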
@@ -126,7 +95,7 @@ You may see a small performance boost in VAE decode on multi-image batches. Ther
 
 Tiled VAE processing makes it possible to work with large images on limited VRAM. For example, generating 4k images in 8GB of VRAM. Tiled VAE decoder splits the image into overlapping tiles, decodes the tiles, and blends the outputs to make the final image.
 
-You want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
+You want to couple this with [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
 
 To use tiled VAE processing, invoke [`~StableDiffusionPipeline.enable_vae_tiling`] in your pipeline before inference. For example:
 
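As above, the tiled VAE example itself sits outside the hunk. A minimal sketch, assuming the same checkpoint and an illustrative 2560x1440 output size:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Decode the latents tile by tile and blend the overlapping tiles,
# so large images fit in limited VRAM.
pipe.enable_vae_tiling()
image = pipe("a beautiful landscape photograph", width=2560, height=1440).images[0]
```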
@@ -192,7 +161,6 @@ pipe = StableDiffusionPipeline.from_pretrained(
 
 prompt = "a photo of an astronaut riding a horse on mars"
 pipe.enable_sequential_cpu_offload()
-pipe.enable_attention_slicing(1)
 
 image = pipe(prompt).images[0]
 ```
@@ -241,7 +209,6 @@ pipe = StableDiffusionPipeline.from_pretrained(
 
 prompt = "a photo of an astronaut riding a horse on mars"
 pipe.enable_model_cpu_offload()
-pipe.enable_attention_slicing(1)
 
 image = pipe(prompt).images[0]
 ```

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 27 additions & 2 deletions
@@ -1669,15 +1669,40 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor
-        in slices to compute attention in several steps. This is useful to save some memory in exchange for a small
-        speed decrease.
+        in slices to compute attention in several steps. For more than one attention head, the computation is performed
+        sequentially over each head. This is useful to save some memory in exchange for a small speed decrease.
+
+        <Tip warning={true}>
+
+        ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch
+        2.0 or xFormers. These attention computations are already very memory efficient so you won't need to enable
+        this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs!
+
+        </Tip>
 
         Args:
             slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                 `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionPipeline
+
+        >>> pipe = StableDiffusionPipeline.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5",
+        ...     torch_dtype=torch.float16,
+        ...     use_safetensors=True,
+        ... )
+
+        >>> prompt = "a photo of an astronaut riding a horse on mars"
+        >>> pipe.enable_attention_slicing()
+        >>> image = pipe(prompt).images[0]
+        ```
         """
         self.set_attention_slice(slice_size)
 
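To make the `slice_size` options in the docstring above concrete, here is a short hedged sketch (reusing a `pipe` loaded as in the docstring example; the integer value is illustrative):

```python
# "auto" (the default): halve the input to the attention heads,
# so attention is computed in two steps.
pipe.enable_attention_slicing("auto")

# "max": run only one slice at a time for the largest memory savings
# (and the slowest attention computation).
pipe.enable_attention_slicing("max")

# An integer: use attention_head_dim // slice_size slices;
# attention_head_dim must be a multiple of slice_size.
pipe.enable_attention_slicing(2)
```

As the warning in the docstring notes, skip attention slicing entirely if the pipeline already uses SDPA or xFormers attention.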
