Remove redundant argument apply_t5_attn_mask
duongve13112002 committed Feb 7, 2026
commit 96a3ae2f87fc69e8b564c64db8c749489ef83d00
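
The commit title calls the argument redundant; the docs below note that the LLM Adapter's attention requires masks, which suggests the T5 attention mask is consumed unconditionally, so an opt-in switch changed nothing. A minimal before/after sketch of the call sites touched by this commit (illustrative only, not part of the diff):

    # before: every constructor threaded the flag through
    text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
        apply_t5_attn_mask=args.apply_t5_attn_mask,  # removed in this commit
        dropout_rate=caption_dropout_rate,
    )

    # after: mask handling is unconditional; only caption dropout remains configurable
    text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
        dropout_rate=caption_dropout_rate,
    )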
3 changes: 0 additions & 3 deletions anima_train.py
@@ -165,7 +165,6 @@ def train(args):
             args.text_encoder_batch_size,
             False,
             False,
-            args.apply_t5_attn_mask,
         )
     )
     train_dataset_group.set_current_strategies()
@@ -223,7 +222,6 @@ def train(args):
     # Set text encoding strategy
     caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
     text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
-        apply_t5_attn_mask=args.apply_t5_attn_mask,
         dropout_rate=caption_dropout_rate,
     )
     strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
@@ -243,7 +241,6 @@ def train(args):
         args.text_encoder_batch_size,
         args.skip_cache_check,
         is_partial=False,
-        apply_t5_attn_mask=args.apply_t5_attn_mask,
     )
     strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
 
3 changes: 0 additions & 3 deletions anima_train_network.py
@@ -166,7 +166,6 @@ def get_latents_caching_strategy(self, args):
     def get_text_encoding_strategy(self, args):
         caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
         self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
-            apply_t5_attn_mask=args.apply_t5_attn_mask,
             dropout_rate=caption_dropout_rate,
         )
         return self.text_encoding_strategy
@@ -193,7 +192,6 @@ def get_text_encoder_outputs_caching_strategy(self, args):
                 args.text_encoder_batch_size,
                 args.skip_cache_check,
                 is_partial=False,
-                apply_t5_attn_mask=args.apply_t5_attn_mask,
             )
         return None
 
@@ -471,7 +469,6 @@ def get_sai_model_spec(self, args):
         return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)
 
     def update_metadata(self, metadata, args):
-        metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
         metadata["ss_weighting_scheme"] = args.weighting_scheme
         metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
         metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
5 changes: 0 additions & 5 deletions docs/anima_train_network.md
@@ -169,8 +169,6 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
   - Maximum token length for the Qwen3 tokenizer. Default `512`.
 * `--t5_max_token_length=<integer>`
   - Maximum token length for the T5 tokenizer. Default `512`.
-* `--apply_t5_attn_mask`
-  - Apply attention mask to T5 tokens in the LLM adapter.
 * `--flash_attn`
   - Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`. Falls back to PyTorch SDPA if the package is not installed. Note: Flash Attention is only applied to DiT blocks; the LLM Adapter uses standard attention because it requires attention masks.
 * `--transformer_dtype=<choice>`
@@ -229,7 +227,6 @@ Anima supports 6 independent learning rate groups. Set to `0` to freeze a compon
 * `--sigmoid_scale` - Scale factor for logit_normal timestep sampling. Default `1.0`.
 * `--qwen3_max_token_length` - Maximum token length for the Qwen3 tokenizer. Default `512`.
 * `--t5_max_token_length` - Maximum token length for the T5 tokenizer. Default `512`.
-* `--apply_t5_attn_mask` - Apply attention mask to T5 tokens in the LLM Adapter.
 * `--flash_attn` - Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`.
 * `--transformer_dtype` - Separate dtype for the Transformer blocks.
 
@@ -537,7 +534,6 @@ Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレー
 
 The following Anima-specific metadata is saved in the LoRA model file:
 
-* `ss_apply_t5_attn_mask`
 * `ss_weighting_scheme`
 * `ss_discrete_flow_shift`
 * `ss_timestep_sample_method`
@@ -552,7 +548,6 @@ The following Anima-specific metadata is saved in the LoRA model file:
 
 The following Anima-specific metadata is saved in the LoRA model file:
 
-* `ss_apply_t5_attn_mask`
 * `ss_weighting_scheme`
 * `ss_discrete_flow_shift`
 * `ss_timestep_sample_method`
5 changes: 0 additions & 5 deletions library/anima_train_utils.py
@@ -106,11 +106,6 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
         default=512,
         help="Maximum token length for T5 tokenizer (default: 512)",
     )
-    parser.add_argument(
-        "--apply_t5_attn_mask",
-        action="store_true",
-        help="Apply attention mask to T5 tokens in LLM adapter",
-    )
     parser.add_argument(
         "--discrete_flow_shift",
         type=float,
8 changes: 0 additions & 8 deletions library/strategy_anima.py
@@ -86,18 +86,15 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
 
     def __init__(
         self,
-        apply_t5_attn_mask: bool = False,
         dropout_rate: float = 0.0,
     ) -> None:
-        self.apply_t5_attn_mask = apply_t5_attn_mask
         self.dropout_rate = dropout_rate
 
     def encode_tokens(
         self,
         tokenize_strategy: TokenizeStrategy,
         models: List[Any],
         tokens: List[torch.Tensor],
-        apply_t5_attn_mask: Optional[bool] = None,
         enable_dropout: bool = True,
     ) -> List[torch.Tensor]:
         """Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -109,8 +106,6 @@ def encode_tokens(
         Returns:
             [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
         """
-        if apply_t5_attn_mask is None:
-            apply_t5_attn_mask = self.apply_t5_attn_mask
 
         qwen3_text_encoder = models[0]
         qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
@@ -222,10 +217,8 @@ def __init__(
         batch_size: int,
         skip_disk_cache_validity_check: bool,
         is_partial: bool = False,
-        apply_t5_attn_mask: bool = False,
     ) -> None:
         super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
-        self.apply_t5_attn_mask = apply_t5_attn_mask
 
     def get_outputs_npz_path(self, image_abs_path: str) -> str:
         return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
@@ -279,7 +272,6 @@ def cache_batch_outputs(
             tokenize_strategy,
             models,
             tokens_and_masks,
-            apply_t5_attn_mask=self.apply_t5_attn_mask,
             enable_dropout=False,
         )
 
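
For reference, the resulting strategy interface after these hunks (a sketch assembled from the context lines above, not the full implementation):

    class AnimaTextEncodingStrategy(TextEncodingStrategy):
        def __init__(self, dropout_rate: float = 0.0) -> None:
            self.dropout_rate = dropout_rate

        def encode_tokens(
            self,
            tokenize_strategy: TokenizeStrategy,
            models: List[Any],
            tokens: List[torch.Tensor],
            enable_dropout: bool = True,
        ) -> List[torch.Tensor]:
            # Returns [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask];
            # the T5 attention mask is now always produced and passed along,
            # with no per-call or per-instance toggle.
            ...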
4 changes: 1 addition & 3 deletions tests/test_anima_cache.py
@@ -200,7 +200,6 @@ def test_text_encoder_cache(args, pairs):
         t5_max_length=args.t5_max_length,
     )
     text_encoding_strategy = AnimaTextEncodingStrategy(
-        apply_t5_attn_mask=False,
         dropout_rate=0.0,
     )
 
@@ -355,7 +354,6 @@ def test_text_encoder_cache(args, pairs):
     # Test drop_cached_text_encoder_outputs
     print(f"\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
     dropout_strategy = AnimaTextEncodingStrategy(
-        apply_t5_attn_mask=False,
         dropout_rate=0.5,  # high rate to ensure some drops
     )
     dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
@@ -401,7 +399,7 @@ def test_full_batch_simulation(args, pairs):
         qwen3_tokenizer=qwen3_tokenizer, t5_tokenizer=t5_tokenizer,
         qwen3_max_length=args.qwen3_max_length, t5_max_length=args.t5_max_length,
     )
-    text_encoding_strategy = AnimaTextEncodingStrategy(apply_t5_attn_mask=False, dropout_rate=0.0)
+    text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
 
     captions = [cap for _, cap in pairs]
 
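
After this commit, callers construct the strategy with dropout only. A hypothetical minimal call site mirroring the updated test (variable names such as `qwen3_text_encoder` and `tokens_and_masks` are taken from the context lines above; the surrounding setup is assumed):

    text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
    # encode without caption dropout, as the caching path does
    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
        tokenize_strategy, [qwen3_text_encoder], tokens_and_masks, enable_dropout=False
    )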