[training, perf] fix: THD-aware FLOPS via cu_seqlens (Σᵢ sᵢ²)#3839

Open
cuichenx wants to merge 1 commit into main from chcui/thd-flops-fix

Summary

The current FLOPS calculator treats every batch as BSHD with the attention term scaling as seq_length² (or pack_length² after #3529). For THD packed training — both offline-packed LLM SFT (via PackedSequenceSpecs) and VLM in-batch packing — the actual attention work is Σᵢ sᵢ² over the sub-sequence lengths within each pack, not pack_length². This PR threads cu_seqlens from the dataloader into the FLOPS accumulator so the attention term reports the true work.

The fix

New helper accumulate_flops_metadata() in src/megatron/bridge/training/utils/flop_utils.py:

  • Computes seqlen_sum (linear) and seqlen_sq_sum (quadratic, attention-term driver).
  • When cu_seqlens (preferring cu_seqlens_unpadded if present) is available, derives seqlen_sq_sum = Σᵢ sᵢ² over real sub-sequences; truncates the padded tail via cu_seqlens_argmin.
  • Falls back to the existing BSHD formula mbs * seq_len² when no cu_seqlens is provided (dense pretraining / non-packed paths) — bit-exact identical to pre-fix behavior.
  • Also accumulates vision_patches from optional image/video grid tensors.
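
The accumulation rules above can be sketched in plain Python. This is a minimal illustration, not the PR's implementation: the function name `sketch_seqlen_sums` and its signature are hypothetical, it works on lists rather than tensors, and it assumes `cu_seqlens_argmin` is the index of the first padding entry in a padded `cu_seqlens`.

```python
from typing import Optional, Sequence, Tuple


def sketch_seqlen_sums(
    mbs: int,
    seq_len: int,
    cu_seqlens: Optional[Sequence[int]] = None,
    cu_seqlens_argmin: Optional[int] = None,
) -> Tuple[int, int]:
    """Return (seqlen_sum, seqlen_sq_sum) for one micro-batch.

    Hypothetical stand-in for the quantities the PR describes.
    """
    # Linear term: token slots in the micro-batch (mbs * tokens.shape[1]).
    seqlen_sum = mbs * seq_len
    # BSHD fallback: every sample attends over the full sequence length.
    bshd_sq_sum = mbs * seq_len**2

    if cu_seqlens is None:
        return seqlen_sum, bshd_sq_sum

    # Assumption: entries from cu_seqlens_argmin onward are padding.
    if cu_seqlens_argmin is not None:
        cu_seqlens = cu_seqlens[:cu_seqlens_argmin]

    # Fewer than two offsets means there are no sub-sequence boundaries.
    if len(cu_seqlens) < 2:
        return seqlen_sum, bshd_sq_sum

    # Consecutive differences of the cumulative offsets are the lengths s_i.
    lengths = [end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])]
    return seqlen_sum, sum(s * s for s in lengths)
```

For a pack of 8 tokens holding sub-sequences of lengths 3, 2, and 3, the sketch yields a quadratic sum of 22 instead of the BSHD value of 64.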

Wired into four step functions (consumed by the existing train.py reset/read flow merged in #3529):

  • src/megatron/bridge/training/gpt_step.py — LLM, offline THD via PackedSequenceSpecs
  • src/megatron/bridge/training/vlm_step.py — legacy VLM in-batch packing
  • src/megatron/bridge/models/qwen_vl/qwen3_vl_step.py — Qwen3-VL family in-batch packing
  • src/megatron/bridge/models/qwen_omni/qwen3_omni_step.py — Qwen3-Omni (no packing today; calls helper for vision-patch tracking and BSHD seqlen²)

Why

For a packed batch of length L_pack holding N real sub-sequences of lengths s₁ … s_N:

  • BSHD approximation (old): attention FLOPS ∝ L_pack²
  • THD truth (new): attention FLOPS ∝ Σᵢ sᵢ²

The two are equal only when N == 1 (a single sub-sequence fills the pack). For typical packed SFT or in-batch VLM packing, Σᵢ sᵢ² ≪ L_pack² because tokens never attend across sub-sequence boundaries.
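
The gap between the two terms follows from expanding the square. The sketch below assumes only that the real sub-sequences fit inside the pack, so that their lengths sum to at most L_pack; the cross terms sᵢsⱼ are exactly the cross-sub-sequence attention that THD never computes.

```latex
\[
\Bigl(\sum_{i=1}^{N} s_i\Bigr)^{2}
  = \sum_{i=1}^{N} s_i^{2} + 2\sum_{1 \le i < j \le N} s_i s_j
  \;\ge\; \sum_{i=1}^{N} s_i^{2},
\qquad
\sum_{i=1}^{N} s_i \le L_{\text{pack}}
\;\Longrightarrow\;
\sum_{i=1}^{N} s_i^{2} \le L_{\text{pack}}^{2}.
\]
\[
\text{Equal-length case: } s_i = \frac{L_{\text{pack}}}{N}
\;\Longrightarrow\;
\sum_{i=1}^{N} s_i^{2} = \frac{L_{\text{pack}}^{2}}{N}.
\]
```

In the equal-length case the BSHD approximation over-counts the quadratic term by exactly a factor of N.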

Verification

Two otherwise identical workspaces on cw-dfw; only the 5 source files differ. Same model, same dataset, same seed (rng.seed=42), same parallelism. Loss values and iteration times are identical across paired runs, which confirms that only the reported FLOPS changes.

| Recipe | seq | Packing | Baseline (origin/main), TFLOP/s/GPU | Fix (this PR), TFLOP/s/GPU | Baseline over-count |
|---|---|---|---|---|---|
| qwen3_8b_sft_config | 2048 | offline THD (SQuAD) | 162.6 | 155.8 | +4.4% |
| qwen3_8b_sft_config | 4096 | offline THD (SQuAD) | 339.9 | 156.7 | +117% (2.17×) |
| qwen35_vl_9b_sft_config | 4096 | in-batch (CORD-v2) | 261.6 | 88.9 | +194% (~3×) |

Means computed over 26 post-warmup iters (iters 5-30).
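
The over-count column can be re-derived from the two throughput columns as baseline / fix. The snippet below is only an arithmetic check on the numbers reported above; the dictionary keys and the `over_count` name are illustrative labels, not identifiers from the codebase.

```python
# (baseline, fix) mean TFLOP/s/GPU, copied from the verification table.
runs = {
    "qwen3_8b_sft seq=2048": (162.6, 155.8),
    "qwen3_8b_sft seq=4096": (339.9, 156.7),
    "qwen35_vl_9b_sft seq=4096": (261.6, 88.9),
}


def over_count(baseline: float, fix: float) -> float:
    """Ratio by which the baseline over-reports throughput."""
    return baseline / fix


for name, (baseline, fix) in runs.items():
    ratio = over_count(baseline, fix)
    print(f"{name}: {ratio:.2f}x (+{(ratio - 1) * 100:.1f}%)")
```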

The seq=2048 vs seq=4096 pair on the same LLM recipe is the cleanest demonstration: the fix reports near-identical TFLOP/s/GPU (155.8 vs 156.7) because actual attention work is determined by per-sample lengths, not pack length. The baseline roughly doubles (162.6 → 339.9) because its pack_length² formula scales quadratically with the pack size while the actual work does not, so the over-count grows with pack length.

The VLM case shows the largest over-count because in-batch packing concentrates many short multimodal samples into each 4K pack, where pack_length² is dramatically larger than Σᵢ sᵢ².

Sanity checks (per run, in log header):

mcore: /chcui/megatron-lm-mainHEAD/megatron/core/__init__.py
bridge: /chcui/Megatron-Bridge[-flops-baseline]/src/megatron/bridge/__init__.py
has accumulate_flops_metadata: True  (fix) / False (baseline)

Tests

9 new unit tests in tests/unit_tests/training/utils/test_flop_utils.py::TestAccumulateFlopsMetadata cover:

  • BSHD fallback (no cu_seqlens) → mbs * seq_len²
  • THD with cu_seqlens → Σᵢ sᵢ²
  • Padded cu_seqlens with cu_seqlens_argmin truncation
  • cu_seqlens_unpadded precedence over padded variant
  • Additive accumulation across micro-batches
  • Tokens-None no-op
  • Image + video grid_thw accumulation
  • Empty cu_seqlens fallback to BSHD
  • Regression check: 32-sample pack at length 8192 → 32× smaller attention work than BSHD approximation

Full 62-test test_flop_utils.py suite passes (53 pre-existing + 9 new).
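
The 32× figure in the regression check can be reproduced by hand. The sketch below assumes the pack holds 32 equal-length samples of 256 tokens each (the actual test fixture is not shown here) and uses a hypothetical `attention_work` helper rather than the PR's function.

```python
from itertools import accumulate


def attention_work(cu_seqlens):
    """Sum of squared sub-sequence lengths from cumulative offsets."""
    return sum((end - start) ** 2 for start, end in zip(cu_seqlens, cu_seqlens[1:]))


pack_length = 8192
num_samples = 32
sample_len = pack_length // num_samples  # assumed equal-length samples

# Cumulative offsets 0, 256, 512, ..., 8192.
cu_seqlens = [0, *accumulate([sample_len] * num_samples)]

thd_work = attention_work(cu_seqlens)  # sum of s_i squared
bshd_work = pack_length**2             # BSHD approximation
ratio = bshd_work // thd_work
print(f"BSHD / THD attention work = {ratio}x")
```

With unequal sample lengths the ratio would be smaller than the number of samples, since Σᵢ sᵢ² is minimized when all lengths are equal.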

Out of scope

  • llava_step.py and audio_lm_step.py still fall back to legacy cfg.model.seq_length (helper not wired). Easy follow-up.
  • Hybrid model path (hybrid_flops in flop_utils) still passes scalar seq_len, dropping quadratic accuracy for hybrid attention layers — independent of this PR.
  • CP > 1 has a pre-existing inconsistency between linear (mbs * tokens.shape[1], per-CP-rank) and quadratic (Σ sᵢ² from full cu_seqlens, global) accumulator scale. Not introduced here; not exercised by the verification runs (CP=1).

Labels

bug · area:perf · area:training · needs-review

Packed THD training (offline-packed LLM SFT and VLM in-batch packing)
over-counts attention FLOPS by treating the whole pack as one length-
seq_length sequence (pack_length²). Actual attention work is Σᵢ sᵢ²
over the real sub-sequence lengths.

New helper accumulate_flops_metadata() in flop_utils.py extracts the
real sub-seq lengths from cu_seqlens (preferring cu_seqlens_unpadded
when present) and feeds Σᵢ sᵢ² into the existing seqlen_squared_sum
accumulator from #3529. Falls back to BSHD mbs * seq_len² when no
cu_seqlens is provided — bit-exact identical to legacy on dense
pretraining and non-packed paths.

Wired into gpt_step, vlm_step, qwen3_vl_step, and qwen3_omni_step.

Verified on cw-dfw (same seed, same data, same iter times, identical
loss values across paired runs — only the reported TFLOPS differs):
- qwen3_8b_sft  seq=2048: baseline 162.6 vs fix 155.8 TFLOP/s/GPU (+4%)
- qwen3_8b_sft  seq=4096: baseline 339.9 vs fix 156.7 TFLOP/s/GPU (+117%)
- qwen35_vl_9b_sft     : baseline 261.6 vs fix  88.9 TFLOP/s/GPU (+194%)

The seq=2048→4096 pair on the same LLM recipe is the cleanest
demonstration: the fix is near-flat (155.8 vs 156.7 — attention work
is determined by per-sample lengths, not pack length) while the
baseline doubles because its pack_length² scales quadratically.

9 new unit tests in test_flop_utils.py::TestAccumulateFlopsMetadata
cover the BSHD fallback, THD with cu_seqlens, padded cu_seqlens via
cu_seqlens_argmin, cu_seqlens_unpadded precedence, additive
accumulation, and the regression headline (32-sample pack → 32x
smaller attention work than BSHD approximation).

Signed-off-by: Chen Cui <chcui@nvidia.com>
cuichenx added the bug, area:perf, area:training, and needs-review labels on May 15, 2026.
yaoyu-33 removed the area:training label on May 18, 2026.