1 change: 1 addition & 0 deletions README.md
@@ -184,6 +184,7 @@ To achieve optimal results with OmniGen2, you can adjust the following key hyperparameters:
- `num_inference_step`: Number of discretization steps for the ODE solver. Default is `50`.
- `enable_teacache`: Whether or not to enable [teacache](https://github.com/ali-vilab/TeaCache) for faster inference.
- `teacache_rel_l1_thresh`: The threshold for accumulated L1 distance for the timestep embedding-modulated noisy input. It serves as an indicator of whether to cache the model output. You can modify the `teacache_rel_l1_thresh` parameter to achieve your desired trade-off between latency and visual quality. The default value of 0.05 provides approximately a **30% speedup** compared to the baseline. Increasing this value can further reduce latency, but may result in some loss of detail.
- `enable_taylorseer`: Whether or not to enable [taylorseer](https://github.com/Shenyi-Z/TaylorSeer) for faster inference. When enabled, inference speed can improve by up to 2X, with negligible quality loss compared to the baseline.
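
As a rough sketch of how these two flags are wired up (attribute names are taken from the `inference.py` changes below; the surrounding variables are illustrative, not an exact invocation):

```python
# Illustrative only: `pipeline` is assumed to be an already-loaded OmniGen2Pipeline,
# and the two booleans mirror the CLI flags added in this PR.
if enable_taylorseer:
    # TaylorSeer takes precedence; TeaCache is ignored when both are requested.
    pipeline.enable_taylorseer = True
elif enable_teacache:
    pipeline.transformer.enable_teacache = True
    pipeline.transformer.teacache_rel_l1_thresh = 0.05  # larger = faster, less detail
```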

**Some suggestions for improving generation quality:**
1. Use High-Quality Images
13 changes: 12 additions & 1 deletion inference.py
@@ -164,6 +164,11 @@ def parse_args() -> argparse.Namespace:
default=0.05,
help="Relative L1 threshold for teacache."
)
parser.add_argument(
"--enable_taylorseer",
action="store_true",
help="Enable TaylorSeer Caching."
)
return parser.parse_args()

def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
@@ -190,7 +195,12 @@ def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
print(f"LoRA weights loaded from {args.transformer_lora_path}")
pipeline.load_lora_weights(args.transformer_lora_path)

if args.enable_teacache:
if args.enable_teacache and args.enable_taylorseer:
print("WARNING: enable_teacache and enable_taylorseer are mutually exclusive. enable_teacache will be ignored.")

if args.enable_taylorseer:
pipeline.enable_taylorseer = True
elif args.enable_teacache:
pipeline.transformer.enable_teacache = True
pipeline.transformer.teacache_rel_l1_thresh = args.teacache_rel_l1_thresh

@@ -214,6 +224,7 @@ def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
apply_group_offloading(pipeline.vae, onload_device=accelerator.device, offload_type="block_level", num_blocks_per_group=2, use_stream=True)
else:
pipeline = pipeline.to(accelerator.device)

return pipeline

def preprocess(input_image_path: List[str] = []) -> Tuple[str, str, List[Image.Image]]:
3 changes: 3 additions & 0 deletions omnigen2/cache_functions/__init__.py
@@ -0,0 +1,3 @@
from .cache_init import cache_init
from .cal_type import cal_type
from .force_scheduler import force_scheduler
38 changes: 38 additions & 0 deletions omnigen2/cache_functions/cache_init.py
@@ -0,0 +1,38 @@
# Modified from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cache_init.py

# Type hinting would cause a circular import; `self` is expected to be an `OmniGen2Pipeline`
def cache_init(self, num_steps: int):
'''
Initialization for cache.
'''
cache_dic = {}
cache = {}
cache_index = {}
cache[-1] = {}
cache_index[-1] = {}
cache_index['layer_index'] = {}
cache[-1]['layers_stream'] = {}
cache_dic['cache_counter'] = 0

for j in range(len(self.transformer.layers)):
cache[-1]['layers_stream'][j] = {}
cache_index[-1][j] = {}

cache_dic['Delta-DiT'] = False
cache_dic['cache_type'] = 'random'
cache_dic['cache_index'] = cache_index
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.0
cache_dic['fresh_threshold'] = 3
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['taylor_cache'] = True
cache_dic['max_order'] = 6
cache_dic['first_enhance'] = 5

current = {}
current['activated_steps'] = [0]
current['step'] = 0
current['num_steps'] = num_steps

return cache_dic, current
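
For orientation, a small hypothetical sketch of how the returned structures are consumed (the real driver is the pipeline change further down; `pipe` here stands in for an `OmniGen2Pipeline`):

```python
# Hypothetical usage sketch (not part of the PR): inspect what cache_init sets up.
cache_dic, current = cache_init(pipe, num_steps=50)

assert current['step'] == 0 and current['activated_steps'] == [0]
assert cache_dic['taylor_cache'] and cache_dic['max_order'] == 6

# Each denoising step, cal_type() sets current['type'] to 'full' or 'Taylor';
# per-layer Taylor derivatives live in cache_dic['cache'][-1]['layers_stream'][layer_idx].
```
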
41 changes: 41 additions & 0 deletions omnigen2/cache_functions/cal_type.py
@@ -0,0 +1,41 @@
# Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cal_type.py

from .force_scheduler import force_scheduler

def cal_type(cache_dic, current):
'''
Determine calculation type for this step
'''
if (cache_dic['fresh_ratio'] == 0.0) and (not cache_dic['taylor_cache']):
# FORA:Uniform
first_step = (current['step'] == 0)
else:
# ToCa: First enhanced
first_step = (current['step'] < cache_dic['first_enhance'])

if not first_step:
fresh_interval = cache_dic['cal_threshold']
else:
fresh_interval = cache_dic['fresh_threshold']

if (first_step) or (cache_dic['cache_counter'] == fresh_interval - 1 ):
current['type'] = 'full'
cache_dic['cache_counter'] = 0
current['activated_steps'].append(current['step'])
force_scheduler(cache_dic, current)

elif (cache_dic['taylor_cache']):
cache_dic['cache_counter'] += 1
current['type'] = 'Taylor'


elif (cache_dic['cache_counter'] % 2 == 1): # 0: ToCa-Aggressive-ToCa, 1: Aggressive-ToCa-Aggressive
cache_dic['cache_counter'] += 1
current['type'] = 'ToCa'
# 'cache_noise' 'ToCa' 'FORA'
elif cache_dic['Delta-DiT']:
cache_dic['cache_counter'] += 1
current['type'] = 'Delta-Cache'
else:
cache_dic['cache_counter'] += 1
current['type'] = 'ToCa'
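
To see the schedule these defaults produce, here is a small standalone simulation (it assumes the `omnigen2` package above is importable; the dictionaries mirror the defaults from `cache_init`):

```python
# With taylor_cache=True, fresh_ratio=0.0, first_enhance=5 and fresh_threshold=3,
# the first 5 steps run 'full', after which every 3rd step is 'full' and the rest
# are 'Taylor' (extrapolated from the cache).
from omnigen2.cache_functions import cal_type

cache_dic = {'fresh_ratio': 0.0, 'taylor_cache': True, 'first_enhance': 5,
             'fresh_threshold': 3, 'cache_counter': 0, 'Delta-DiT': False}
current = {'step': 0, 'num_steps': 12, 'activated_steps': [0]}

schedule = []
for step in range(current['num_steps']):
    current['step'] = step
    cal_type(cache_dic, current)
    schedule.append(current['type'])

print(schedule)
# ['full', 'full', 'full', 'full', 'full', 'Taylor', 'Taylor', 'full',
#  'Taylor', 'Taylor', 'full', 'Taylor']
```
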
19 changes: 19 additions & 0 deletions omnigen2/cache_functions/force_scheduler.py
@@ -0,0 +1,19 @@
# Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/force_scheduler.py

import torch

def force_scheduler(cache_dic, current):
if cache_dic['fresh_ratio'] == 0:
# FORA
linear_step_weight = 0.0
else:
# TokenCache
linear_step_weight = 0.0
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current['step'] / current['num_steps'])
threshold = torch.round(cache_dic['fresh_threshold'] / step_factor)

# No forced constraint for sensitive steps, because the performance is already good enough.
# You may experiment with adding one.

cache_dic['cal_threshold'] = threshold
#return threshold
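
A quick numeric illustration of the formula (the `w` values other than 0.0 are hypothetical; both branches above currently hard-code `linear_step_weight = 0.0`, so the refresh interval stays constant):

```python
import torch

def refresh_interval(fresh_threshold, step, num_steps, w):
    # step_factor grows linearly from (1 - w) at step 0 to (1 + w) at the last step,
    # so a positive w refreshes less often early and more often late.
    step_factor = torch.tensor(1 - w + 2 * w * step / num_steps)
    return torch.round(fresh_threshold / step_factor)

print(refresh_interval(3, 0, 50, 0.0))   # tensor(3.) -- default: constant interval
print(refresh_interval(3, 0, 50, 0.5))   # tensor(6.) -- early steps cached longer
print(refresh_interval(3, 50, 50, 0.5))  # tensor(2.) -- late steps refreshed sooner
```
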
106 changes: 81 additions & 25 deletions omnigen2/models/transformers/transformer_omnigen2.py
@@ -29,6 +29,9 @@
else:
from torch.nn import RMSNorm

from ...taylorseer_utils import derivative_approximation, taylor_formula, taylor_cache_init
from ...cache_functions import cache_init, cal_type

logger = logging.get_logger(__name__)

class OmniGen2TransformerBlock(nn.Module):
@@ -149,32 +152,69 @@ def forward(
Returns:
torch.Tensor: Output hidden states after transformer block processing
"""
import time
if self.modulation:
if temb is None:
raise ValueError("temb must be provided when modulation is enabled")

norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
enable_taylorseer = getattr(self, 'enable_taylorseer', False)
if enable_taylorseer:
if self.modulation:
if temb is None:
raise ValueError("temb must be provided when modulation is enabled")

if self.current['type'] == 'full':
self.current['module'] = 'total'
taylor_cache_init(cache_dic=self.cache_dic, current=self.current)

norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)

derivative_approximation(cache_dic=self.cache_dic, current=self.current, feature=hidden_states)

elif self.current['type'] == 'Taylor':
self.current['module'] = 'total'
hidden_states = taylor_formula(cache_dic=self.cache_dic, current=self.current)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
if self.modulation:
if temb is None:
raise ValueError("temb must be provided when modulation is enabled")

norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)

return hidden_states
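
The helpers imported above (`taylor_cache_init`, `derivative_approximation`, `taylor_formula`) come from `omnigen2/taylorseer_utils.py`, which is not shown in this diff. As a rough, hypothetical sketch of the idea behind them (a finite-difference Taylor cache, keyed only by module here for brevity; the real code also distinguishes streams and layers):

```python
# Hypothetical sketch, NOT the actual omnigen2/taylorseer_utils.py.
import math

def taylor_cache_init(cache_dic, current):
    # On a 'full' step, make sure a derivative table exists for the current module.
    cache_dic['cache'][-1].setdefault(current['module'], {})

def derivative_approximation(cache_dic, current, feature):
    # Refresh the 0th..max_order finite-difference "derivatives" of the feature,
    # measured between the most recent fully computed ('activated') steps.
    steps = current['activated_steps']
    dt = (steps[-1] - steps[-2]) if len(steps) > 1 else 1
    old = cache_dic['cache'][-1][current['module']]
    new = {0: feature}
    for order in range(cache_dic['max_order']):
        if order in old:
            new[order + 1] = (new[order] - old[order]) / dt
        else:
            break
    cache_dic['cache'][-1][current['module']] = new

def taylor_formula(cache_dic, current):
    # On a 'Taylor' step, extrapolate the cached feature to the current step
    # instead of running attention and the feed-forward network.
    dt = current['step'] - current['activated_steps'][-1]
    table = cache_dic['cache'][-1][current['module']]
    return sum(table[o] * dt ** o / math.factorial(o) for o in table)
```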

@@ -516,6 +556,10 @@ def forward(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
enable_taylorseer = getattr(self, 'enable_taylorseer', False)
if enable_taylorseer:
cal_type(self.cache_dic, self.current)

if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -632,7 +676,16 @@ def forward(
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
self.teacache_params.previous_residual = hidden_states - ori_hidden_states
else:
if enable_taylorseer:
self.current['stream'] = 'layers_stream'

for layer_idx, layer in enumerate(self.layers):
if enable_taylorseer:
layer.current = self.current
layer.cache_dic = self.cache_dic
layer.enable_taylorseer = True
self.current['layer'] = layer_idx

if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
layer, hidden_states, attention_mask, rotary_emb, temb
@@ -654,6 +707,9 @@ def forward(
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if enable_taylorseer:
self.current['step'] += 1

if not return_dict:
return output
62 changes: 37 additions & 25 deletions omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
@@ -59,6 +59,7 @@
else:
XLA_AVAILABLE = False

from ...cache_functions import cache_init

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -635,17 +636,25 @@ def processing(
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

# Use different TeaCacheParams for different conditions
if self.transformer.enable_teacache:

enable_taylorseer = getattr(self, "enable_taylorseer", False)
if enable_taylorseer:
model_pred_cache_dic, model_pred_current = cache_init(self, num_inference_steps)
model_pred_ref_cache_dic, model_pred_ref_current = cache_init(self, num_inference_steps)
model_pred_uncond_cache_dic, model_pred_uncond_current = cache_init(self, num_inference_steps)
self.transformer.enable_taylorseer = True
elif self.transformer.enable_teacache:
# Use different TeaCacheParams for different conditions
teacache_params = TeaCacheParams()
teacache_params_uncond = TeaCacheParams()
teacache_params_ref = TeaCacheParams()

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):

if self.transformer.enable_teacache:
if enable_taylorseer:
self.transformer.cache_dic = model_pred_cache_dic
self.transformer.current = model_pred_current
elif self.transformer.enable_teacache:
teacache_params.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
self.transformer.teacache_params = teacache_params

@@ -661,8 +670,10 @@
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0

if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:

if self.transformer.enable_teacache:
if enable_taylorseer:
self.transformer.cache_dic = model_pred_ref_cache_dic
self.transformer.current = model_pred_ref_current
elif self.transformer.enable_teacache:
teacache_params_ref.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
self.transformer.teacache_params = teacache_params_ref

@@ -675,28 +686,29 @@
ref_image_hidden_states=ref_latents,
)

if image_guidance_scale != 1:

if self.transformer.enable_teacache:
teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
self.transformer.teacache_params = teacache_params_uncond
if enable_taylorseer:
self.transformer.cache_dic = model_pred_uncond_cache_dic
self.transformer.current = model_pred_uncond_current
elif self.transformer.enable_teacache:
teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
self.transformer.teacache_params = teacache_params_uncond

model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)
else:
model_pred_uncond = torch.zeros_like(model_pred)
model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)

model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
text_guidance_scale * (model_pred - model_pred_ref)
text_guidance_scale * (model_pred - model_pred_ref)
elif text_guidance_scale > 1.0:

if self.transformer.enable_teacache:
if enable_taylorseer:
self.transformer.cache_dic = model_pred_uncond_cache_dic
self.transformer.current = model_pred_uncond_current
elif self.transformer.enable_teacache:
teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
self.transformer.teacache_params = teacache_params_uncond
