[training, perf] fix: THD-aware FLOPS via cu_seqlens (Σᵢ sᵢ²) #3839
Open
cuichenx wants to merge 1 commit into
Packed THD training (offline-packed LLM SFT and VLM in-batch packing) over-counts attention FLOPS by treating the whole pack as one sequence of length seq_length (pack_length²). Actual attention work is Σᵢ sᵢ² over the real sub-sequence lengths.

New helper accumulate_flops_metadata() in flop_utils.py extracts the real sub-sequence lengths from cu_seqlens (preferring cu_seqlens_unpadded when present) and feeds Σᵢ sᵢ² into the existing seqlen_squared_sum accumulator from #3529. Falls back to BSHD mbs * seq_len² when no cu_seqlens is provided, which is bit-exact identical to legacy behavior on dense pretraining and non-packed paths.

Wired into gpt_step, vlm_step, qwen3_vl_step, and qwen3_omni_step.

Verified on cw-dfw (same seed, same data, same iter times, identical loss values across paired runs; only the reported TFLOPS differs):
- qwen3_8b_sft seq=2048: baseline 162.6 vs fix 155.8 TFLOP/s/GPU (baseline over-count +4%)
- qwen3_8b_sft seq=4096: baseline 339.9 vs fix 156.7 TFLOP/s/GPU (baseline over-count +117%)
- qwen35_vl_9b_sft: baseline 261.6 vs fix 88.9 TFLOP/s/GPU (baseline over-count +194%)

The seq=2048→4096 pair on the same LLM recipe is the cleanest demonstration: the fix is near-flat (155.8 vs 156.7, since attention work is determined by per-sample lengths, not pack length) while the baseline doubles because its pack_length² term scales quadratically.

9 new unit tests in test_flop_utils.py::TestAccumulateFlopsMetadata cover the BSHD fallback, THD with cu_seqlens, padded cu_seqlens via cu_seqlens_argmin, cu_seqlens_unpadded precedence, additive accumulation, and the regression headline (32-sample pack → 32x smaller attention work than the BSHD approximation).

Signed-off-by: Chen Cui <chcui@nvidia.com>
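For readers who want to see the quadratic term concretely, here is a minimal pure-Python sketch of the computation the commit message describes. It is not the code in this PR: the function name `attention_sq_sum` is hypothetical, plain lists stand in for tensors, and it assumes `cu_seqlens_argmin` is the index of the first padding entry in a padded `cu_seqlens`.

```python
from typing import Optional, Sequence


def attention_sq_sum(
    mbs: int,
    seq_len: int,
    cu_seqlens: Optional[Sequence[int]] = None,
    cu_seqlens_argmin: Optional[int] = None,
) -> int:
    """Illustrative only: quadratic attention term for one micro-batch."""
    if cu_seqlens is None:
        # BSHD fallback: each sample counts as one full-length sequence.
        return mbs * seq_len ** 2
    if cu_seqlens_argmin is not None:
        # Assumption: entries from this index onward are padding.
        cu_seqlens = cu_seqlens[:cu_seqlens_argmin]
    # Adjacent differences of the cumulative offsets give sub-sequence lengths.
    lengths = [b - a for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    return sum(s * s for s in lengths)


# 32 equal sub-sequences of 128 tokens packed into one 4096-token row.
packed = list(range(0, 4096 + 1, 128))
thd = attention_sq_sum(mbs=1, seq_len=4096, cu_seqlens=packed)
bshd = attention_sq_sum(mbs=1, seq_len=4096)
print(bshd // thd)  # 32
```

With equal-length sub-sequences the ratio equals the number of packed samples, which matches the 32x regression headline; uneven lengths give a smaller but still substantial gap.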
Summary
The current FLOPS calculator treats every batch as BSHD, with the attention term scaling as `seq_length²` (or `pack_length²` after #3529). For THD packed training, covering both offline-packed LLM SFT (via `PackedSequenceSpecs`) and VLM in-batch packing, the actual attention work is `Σᵢ sᵢ²` over the sub-sequence lengths within each pack, not `pack_length²`. This PR threads `cu_seqlens` from the dataloader into the FLOPS accumulator so the attention term reports the true work.
The fix
New helper `accumulate_flops_metadata()` in `src/megatron/bridge/training/utils/flop_utils.py`:
- Accumulates `seqlen_sum` (linear) and `seqlen_sq_sum` (quadratic, attention-term driver).
- When `cu_seqlens` (preferring `cu_seqlens_unpadded` if present) is available, derives `seqlen_sq_sum = Σᵢ sᵢ²` over real sub-sequences; truncates the padded tail via `cu_seqlens_argmin`.
- Falls back to `mbs * seq_len²` when no `cu_seqlens` is provided (dense pretraining / non-packed paths), bit-exact identical to pre-fix behavior.
- Tracks `vision_patches` from optional image/video grid tensors.
Wired into four step functions (consumed by the existing `train.py` reset/read flow merged in #3529):
- `src/megatron/bridge/training/gpt_step.py`: LLM, offline THD via `PackedSequenceSpecs`
- `src/megatron/bridge/training/vlm_step.py`: legacy VLM in-batch packing
- `src/megatron/bridge/models/qwen_vl/qwen3_vl_step.py`: Qwen3-VL family in-batch packing
- `src/megatron/bridge/models/qwen_omni/qwen3_omni_step.py`: Qwen3-Omni (no packing today; calls helper for vision-patch tracking and BSHD seqlen²)
Why
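For a packed batch of length `L_pack` holding `N` real sub-sequences of lengths `s₁ … s_N`:
- Legacy attention term (BSHD approximation): `L_pack²`
- Actual attention term (THD): `Σᵢ sᵢ²`
Equality holds only when `N == 1` (a single sub-sequence fills the pack). Expanding `(Σᵢ sᵢ)²` gives `Σᵢ sᵢ²` plus the cross terms `2·sᵢ·sⱼ`, and those cross terms are exactly what the legacy formula over-counts. For typical packed SFT or in-batch VLM packing, `Σᵢ sᵢ² ≪ L_pack²` because attention never crosses sub-sequence boundaries.
Verification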
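Two identical workspaces on cw-dfw; only the 5 source files differ. Same model, same dataset, same seed (`rng.seed=42`), same parallelism. Identical loss values and iteration times across paired runs confirm that only the reported FLOPS changes.
- qwen3_8b_sft seq=2048: baseline 162.6 vs fix 155.8 TFLOP/s/GPU (baseline over-count +4%)
- qwen3_8b_sft seq=4096: baseline 339.9 vs fix 156.7 TFLOP/s/GPU (baseline over-count +117%)
- qwen35_vl_9b_sft: baseline 261.6 vs fix 88.9 TFLOP/s/GPU (baseline over-count +194%)
Means computed over 26 post-warmup iters (iters 5-30).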
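The over-count percentages can be re-derived from the TFLOP/s/GPU means reported in this PR. The snippet below is only an arithmetic check; the helper name `overcount_pct` is made up for illustration, and it assumes the percentage is defined as baseline divided by fix, minus one.

```python
# (baseline, fix) mean TFLOP/s/GPU as reported in this PR.
runs = {
    "qwen3_8b_sft seq=2048": (162.6, 155.8),
    "qwen3_8b_sft seq=4096": (339.9, 156.7),
    "qwen35_vl_9b_sft": (261.6, 88.9),
}


def overcount_pct(baseline: float, fix: float) -> int:
    """Relative over-report of the baseline, rounded to a whole percent."""
    return round((baseline / fix - 1.0) * 100)


for name, (baseline, fix) in runs.items():
    print(f"{name}: +{overcount_pct(baseline, fix)}%")
```

Under that definition the three rows come out at +4%, +117%, and +194%.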
The seq=2048 vs seq=4096 pair on the same LLM recipe is the cleanest demonstration: the fix reports near-identical TFLOP/s/GPU (155.8 vs 156.7) because actual attention work is determined by per-sample lengths, not pack length. The baseline doubles (162.6 → 339.9) because its `pack_length²` formula scales quadratically with the pack size while the actual work does not, so the over-count grows with pack length.
The VLM case shows the largest over-count because in-batch packing concentrates many short multimodal samples into each 4K pack, where `pack_length²` is dramatically larger than `Σᵢ sᵢ²`.
Sanity checks (per run, in log header):
Tests
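9 new unit tests in `tests/unit_tests/training/utils/test_flop_utils.py::TestAccumulateFlopsMetadata` cover:
- BSHD fallback: `mbs * seq_len²`
- THD with `cu_seqlens`: `Σᵢ sᵢ²`
- Padded `cu_seqlens` via `cu_seqlens_argmin` truncation
- `cu_seqlens_unpadded` precedence over padded variant
- Additive accumulation
- Regression headline: a 32-sample pack yields 32x smaller attention work than the BSHD approximation
Full 62-test `test_flop_utils.py` suite passes (53 pre-existing + 9 new).
Out of scope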
- `llava_step.py` and `audio_lm_step.py` still fall back to legacy `cfg.model.seq_length` (helper not wired). Easy follow-up.
- The hybrid path (`hybrid_flops` in `flop_utils`) still passes scalar `seq_len`, dropping quadratic accuracy for hybrid attention layers; independent of this PR.
- With CP > 1, the linear (`mbs * tokens.shape[1]`, per-CP-rank) and quadratic (Σ sᵢ² from full `cu_seqlens`, global) accumulators differ in scale. Not introduced here; not exercised by the verification runs (CP=1).
Labels
bug · area:perf · area:training · needs-review