Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
42658fa
Add Support for Z-Image.
JerryWu-code Nov 23, 2025
3e74bb2
Reformatting with make style, black & isort.
JerryWu-code Nov 23, 2025
a4b89a0
Remove init, Modify import utils, Merge forward in transformers block…
JerryWu-code Nov 24, 2025
7df350d
modified main model forward, freqs_cis left
ChrisLiu6 Nov 24, 2025
1dd587b
Merge remote-tracking branch 'JerryWu-code/z-image-dev' into fork/Jer…
ChrisLiu6 Nov 24, 2025
aae03cf
refactored to add B dim
ChrisLiu6 Nov 24, 2025
21d8130
fixed stack issue
ChrisLiu6 Nov 24, 2025
e3dfa9e
fixed modulation bug
ChrisLiu6 Nov 24, 2025
a7fa731
fixed modulation bug
ChrisLiu6 Nov 24, 2025
1e0cefe
fix bug
ChrisLiu6 Nov 24, 2025
7adaae8
remove value_from_time_aware_config
ChrisLiu6 Nov 24, 2025
5b4c907
styling
ChrisLiu6 Nov 24, 2025
2bb39f4
Fix neg embed and divide / bug; Reuse pad zero tensor; Turn cat -> re…
JerryWu-code Nov 24, 2025
71e8049
Replace padding with pad_sequence; Add gradient checkpointing.
JerryWu-code Nov 24, 2025
fbf26b7
Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, repl…
JerryWu-code Nov 24, 2025
6c0c059
Fix Docstring and Make Style.
JerryWu-code Nov 24, 2025
28685dd
Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forwa…
ChrisLiu6 Nov 25, 2025
8e391b7
update z-image docstring
ChrisLiu6 Nov 25, 2025
3b22e84
Revert attention dispatcher
ChrisLiu6 Nov 25, 2025
3d1a7aa
update z-image docstring
ChrisLiu6 Nov 25, 2025
336c5ce
styling
ChrisLiu6 Nov 25, 2025
38a89ed
Recover attention_dispatch.py with its original impl, later would speci…
JerryWu-code Nov 25, 2025
69d61e5
Fix prev bug, and support for prompt_embeds pass in args after prompt…
JerryWu-code Nov 25, 2025
549ad57
Merge branch 'z-image-dev-ql' into z-image-dev
JerryWu-code Nov 25, 2025
1dd8f3c
Remove einops dependency.
JerryWu-code Nov 25, 2025
2f2d8c3
Merge branch 'z-image-dev' into z-image
JerryWu-code Nov 25, 2025
a74a0c4
Merge remote-tracking branch 'origin/main' into z-image
JerryWu-code Nov 25, 2025
e49a1f9
remove redundant imports & make fix-copies
ChrisLiu6 Nov 25, 2025
1048d0a
fix import
ChrisLiu6 Nov 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
styling
  • Loading branch information
ChrisLiu6 committed Nov 24, 2025
commit 5b4c907407f30d16b2fcd948981a04d40bc4198e
26 changes: 5 additions & 21 deletions src/diffusers/models/transformers/transformer_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import math
from typing import List, Optional, Tuple

Expand All @@ -23,11 +22,11 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ..attention_dispatch import dispatch_attention_fn
from ...models.attention_processor import Attention
from ...models.modeling_utils import ModelMixin
from ...utils.import_utils import is_flash_attn_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_dispatch import dispatch_attention_fn


if is_flash_attn_available():
Expand Down Expand Up @@ -99,7 +98,6 @@ def __call__(
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
) -> torch.Tensor:

query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
Expand Down Expand Up @@ -586,11 +584,7 @@ def forward(
dtype=x_freqs_cis[0].dtype,
device=device,
)
x_attn_mask = torch.ones(
(bsz, x_max_item_seqlen),
dtype=torch.bool,
device=device
)
x_attn_mask = torch.ones((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, (item, freqs_item) in enumerate(zip(x, x_freqs_cis)):
seq_len = x_item_seqlens[i]
pad_len = x_max_item_seqlen - seq_len
Expand Down Expand Up @@ -629,11 +623,7 @@ def forward(
dtype=cap_freqs_cis[0].dtype,
device=device,
)
cap_attn_mask = torch.ones(
(bsz, cap_max_item_seqlen),
dtype=torch.bool,
device=device
)
cap_attn_mask = torch.ones((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, (item, freqs_item) in enumerate(zip(cap_feats, cap_freqs_cis)):
seq_len = cap_item_seqlens[i]
pad_len = cap_max_item_seqlen - seq_len
Expand Down Expand Up @@ -672,11 +662,7 @@ def forward(
dtype=unified_freqs_cis[0].dtype,
device=device,
)
unified_attn_mask = torch.ones(
(bsz, unified_max_item_seqlen),
dtype=torch.bool,
device=device
)
unified_attn_mask = torch.ones((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, (item, freqs_item) in enumerate(zip(unified, unified_freqs_cis)):
seq_len = unified_item_seqlens[i]
pad_len = unified_max_item_seqlen - seq_len
Expand All @@ -694,9 +680,7 @@ def forward(
adaln_input,
)

unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
unified, adaln_input
)
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
unified = list(unified.unbind(dim=0))
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

Expand Down