Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
202 commits
Select commit Hold shift + click to select a range
4add1dc
nous changes
jquesnelle Aug 28, 2025
92352f0
add grpo :)
dmahan93 Sep 26, 2025
a8a2049
working version
dmahan93 Oct 3, 2025
de041d8
add config for legacy checkpoint loading
dmahan93 Oct 3, 2025
7fcd8aa
nous changes
jquesnelle Aug 28, 2025
1e44011
don't require a tokenizer or sequence lengths
jquesnelle Oct 4, 2025
9878802
pass position ids through to llama3
jquesnelle Oct 4, 2025
abded1e
hack: fix compile on blackwell
jquesnelle Oct 4, 2025
ae881c8
position ids for deepseek (need to figure out sft vs. pretrain)
jquesnelle Oct 4, 2025
5445dea
only pass sequence_lengths to attention init
jquesnelle Oct 4, 2025
13d201c
merge models into one modeling arch since the differences are qkv bia…
dmahan93 Oct 6, 2025
d7a2bc8
Merge remote-tracking branch 'refs/remotes/origin/dev-updated-again' …
dmahan93 Oct 8, 2025
47abb43
initial commit of qwen3-next
jquesnelle Oct 9, 2025
c013f69
fixes from upstream changes
dmahan93 Oct 9, 2025
62ca23f
Merge pull request #7 from NousResearch/add-grpo
dmahan93 Oct 9, 2025
25dc331
fix gating and activation
jquesnelle Oct 9, 2025
44b99ad
add inference logp IS
dmahan93 Oct 16, 2025
eb1a501
- Took Verl's IS implementation so we can get sequence IS
dmahan93 Oct 17, 2025
03d2da7
Merge pull request #9 from NousResearch/add-inference-logp
dmahan93 Oct 20, 2025
33591ba
add position_ids to qwen3-next
jquesnelle Oct 21, 2025
797db5f
guard import of causal_conv1d and fla
jquesnelle Oct 21, 2025
6786199
Merge pull request #10 from NousResearch/q3n
jquesnelle Oct 21, 2025
d91ee11
add seed-oss support
jquesnelle Oct 21, 2025
139f275
Merge pull request #11 from NousResearch/seed-oss
jquesnelle Oct 21, 2025
5817990
Merge remote-tracking branch 'pytorch/main' into dev-updated-again
jquesnelle Oct 23, 2025
9000e20
nanoset fixes
jquesnelle Oct 24, 2025
09b271a
Enhance Trainer class to accumulate tokens for MFU and throughput com…
ighoshsubho Oct 28, 2025
9b06e67
fix qwen3 moe init
jquesnelle Oct 29, 2025
683e90f
Merge pull request #12 from NousResearch/fix/mfu_cal
dmahan93 Oct 29, 2025
d58345c
add Qwen3 10B-A1B
jquesnelle Oct 29, 2025
be55e2e
fix qwen3-next linear layer causal leakage
jquesnelle Nov 1, 2025
f9e5ad6
add initial deepep and fused all-to-all
Nov 3, 2025
0f4a562
- add mixed precision usage
dmahan93 Nov 3, 2025
0e6acb3
add unit tests for deepep and fused all-to-all
Nov 3, 2025
a3c236e
- removed some old vars
dmahan93 Nov 3, 2025
0c9c4c0
remove use_cuda_num_token_per_expert from fused all-to-all, and fix d…
Nov 4, 2025
2cbee70
add explicit bfloat16 dtype in sglang launcher
dmahan93 Nov 4, 2025
fa525fc
fix missing routed_prob in GroupedExperts.forward()
Nov 4, 2025
7e2089c
fix tokens_per_expert on wrong device
Nov 4, 2025
0b8729f
Merge pull request #14 from NousResearch/fix-mixed-precision
dmahan93 Nov 5, 2025
7f4895e
fix self.cuda_dtoh_stream is none in deepep manager
Nov 5, 2025
d81812f
fix moe block return None outputs
Nov 5, 2025
dc4af3e
fix
Nov 6, 2025
425b1a9
support optional deepep
Nov 6, 2025
66cbd92
add vllm integration for torch 2.9 support
dmahan93 Nov 7, 2025
d3ae956
Merge pull request #15 from NousResearch/add-vllm-server
dmahan93 Nov 7, 2025
9253922
fix issues from dev-updated-again
dmahan93 Nov 7, 2025
7af8bf9
Merge pull request #16 from NousResearch/fix-grpo-on-dev
dmahan93 Nov 7, 2025
1adfdbe
initial test of PTX loss
dmahan93 Nov 8, 2025
b628078
accumulate ptx_loss across DP/grad accum
dmahan93 Nov 8, 2025
88d8f85
sneaking in dynamic batch size fix for num_microbatches...
dmahan93 Nov 8, 2025
94d6abd
squash nous changes
jquesnelle Nov 11, 2025
cc64dbb
add preprocess data script
jquesnelle Nov 12, 2025
8de4a45
add benchmarking deepep and small fixes
Nov 13, 2025
f56546f
add optimal deepep config + auto benchmarking deepep
Nov 13, 2025
b70afef
muon/dion
jquesnelle Nov 14, 2025
10fcf29
clear unsued tensors in fused deepep, unblock lbs=6
Nov 14, 2025
f874271
add memory debugging tools
Nov 17, 2025
db820d9
finish up ptx pr...
dmahan93 Nov 17, 2025
6f47022
Merge remote-tracking branch 'origin/dev-updated-again' into add-ptx-…
dmahan93 Nov 17, 2025
699b214
Merge pull request #17 from NousResearch/add-ptx-loss-to-rl
dmahan93 Nov 19, 2025
434a2b9
Add packed memmap dataset (#18)
dmahan93 Nov 19, 2025
7273bce
clean
xrsrke Nov 21, 2025
8703ffc
remove uncessary code, and add deepep config example
xrsrke Nov 24, 2025
f564ede
clean naming
xrsrke Nov 24, 2025
366776a
add readme for new expert parallelism
xrsrke Nov 24, 2025
57de083
16 nodes is 128 gpus
xrsrke Nov 24, 2025
5d4bdd0
update
xrsrke Nov 24, 2025
e770174
add Qwen3-30B-A3B totla tokens
xrsrke Nov 24, 2025
52ec741
log expert choices to wandb
jquesnelle Nov 26, 2025
1f45362
sft entrypoint
hjc-puro Nov 28, 2025
bc46b87
readme
hjc-puro Nov 28, 2025
1c0ee7c
readme
hjc-puro Nov 28, 2025
1a8978e
flexattn note
hjc-puro Nov 30, 2025
58f2722
wording
hjc-puro Nov 30, 2025
d836de3
fix cuda stream sync for deepep, and add scale expert outputs
xrsrke Dec 1, 2025
06bbc91
hacky patches to fix group_gemm's wrong grads
xrsrke Dec 2, 2025
a60f058
patch the fused_indices_converter, but commen tout the code
xrsrke Dec 4, 2025
495dc7d
add fused silu+expert multiplication
xrsrke Dec 5, 2025
6879d51
Add moe rl (#22)
dmahan93 Dec 5, 2025
e87d906
add fused expert multiplication+deepep, but has a bug not calling the…
xrsrke Dec 5, 2025
f26a0de
disable grad clip when
jquesnelle Dec 6, 2025
a0eb876
support using muon for moe expert and routing params
jquesnelle Dec 6, 2025
61a4020
disable torch.compile for muon-adamw
jquesnelle Dec 6, 2025
1380887
adding fused weighted scatter add, and deepep's comm stream sync
xrsrke Dec 8, 2025
a05ad08
add fused weighted silu activation for moe, and sanity check deepep's…
xrsrke Dec 8, 2025
e792ba2
add fused weighted silu kernel
xrsrke Dec 8, 2025
71ea04e
fixing missing expert-probability multiplication when both fused_weig…
xrsrke Dec 8, 2025
f50e967
Merge pull request #20 from NousResearch/log-expert-routing
jquesnelle Dec 9, 2025
c5bcc6f
fixes for qwen3-next
jquesnelle Dec 11, 2025
7cb7cc7
add config check for fullgraph
jquesnelle Dec 11, 2025
7a5f9f9
Merge pull request #24 from NousResearch/q3n-updates
jquesnelle Dec 11, 2025
06f0e08
Merge pull request #23 from NousResearch/dion
jquesnelle Dec 14, 2025
ac731e5
Merge pull request #21 from NousResearch/sft-entry
hjc-puro Dec 14, 2025
f974a5e
Update README.md
hjc-puro Dec 18, 2025
e7448ff
add random seq packing
hjc-puro Dec 18, 2025
545c051
Merge branch 'dev-updated-again' into phuc/add_deepep
xrsrke Dec 18, 2025
e97c6f7
add n_layers argument to gpt oss moe init
jquesnelle Dec 20, 2025
6720c3e
Merge pull request #26 from NousResearch/gpt-oss-layers-init
jquesnelle Dec 20, 2025
50383a3
add position_ids argument to Qwen3 and GPT-OSS
jquesnelle Dec 21, 2025
a9aaf81
Merge pull request #27 from NousResearch/add-missing-position-ids
jquesnelle Dec 21, 2025
3b9cce1
fix breaks due to git pull
xrsrke Dec 22, 2025
0f8becc
Merge pull request #25 from NousResearch/preprocess
jquesnelle Dec 23, 2025
c19494c
fix qwen3-next hf model conversion
jquesnelle Dec 29, 2025
5f2c1cd
fix qwen3-next w/ muon
jquesnelle Dec 30, 2025
a50b147
support training from qwen3 moe checkpoints with preprocessed datasets
jquesnelle Jan 5, 2026
9349db7
fixes for generation script
jquesnelle Jan 6, 2026
e22b7c3
fix preprocess when not saving (dry run)
jquesnelle Jan 6, 2026
d4a2956
add kimi linear
jquesnelle Jan 6, 2026
e928b1e
add extra timing metrics to diagnose issues (#28)
dmahan93 Jan 6, 2026
9b307fb
fix some breaks due to pulled changes
xrsrke Jan 8, 2026
64cfb6e
Merge branch 'dev-updated-again' into phuc/add_deepep
xrsrke Jan 8, 2026
3f7188b
move the patches fix for grouped_gemmm graient issue to PrimusTurboDe…
xrsrke Jan 8, 2026
a5f5732
clean up
xrsrke Jan 8, 2026
86bf4f2
convert from hf incrementally (#29)
dmahan93 Jan 8, 2026
c3262b5
remove the deprecated training.debug_moe_force_load_balance
xrsrke Jan 8, 2026
7bce1fd
update
xrsrke Jan 8, 2026
4c13048
remove uncessary config, and move router probs to the original dtype
xrsrke Jan 8, 2026
7f171fc
update kimi linear configs
jquesnelle Jan 9, 2026
775a093
Merge branch 'dev-updated-again' into phuc/add_deepep
jquesnelle Jan 9, 2026
aebb3a5
clean
xrsrke Jan 9, 2026
1ab3be4
minor edit in readme
xrsrke Jan 9, 2026
c3f38dd
Merge pull request #19 from NousResearch/phuc/add_deepep
xrsrke Jan 9, 2026
4f5a528
fix oom when save checkpoints
xrsrke Jan 13, 2026
6dd7117
Merge pull request #31 from NousResearch/fix/checkpoint-oom-high-memo…
xrsrke Jan 13, 2026
9c053e7
Enhance convert_to_hf with debug flag and key sorting
dmahan93 Jan 14, 2026
0bf454f
update sft preprocess script to support non packing, fix implicit def…
jquesnelle Jan 14, 2026
9a5f6cf
rerun precommit
xrsrke Jan 14, 2026
d634ef9
Merge pull request #34 from NousResearch/phuc/rerun_precommit
xrsrke Jan 14, 2026
90940da
add conditional import for deepep
xrsrke Jan 14, 2026
b40512e
add DeepEP installation reference to README
xrsrke Jan 14, 2026
ccfa685
Merge pull request #35 from NousResearch/phuc/fix_deepep_import
xrsrke Jan 14, 2026
7cd9e54
add shared expert gate
xrsrke Jan 14, 2026
ba889a8
refacot
xrsrke Jan 14, 2026
544115e
Merge pull request #36 from NousResearch/phuc/fix_add_shared_experts_…
xrsrke Jan 14, 2026
64e7b10
overlap reading from disk and loading to gpu when doing incremental c…
jquesnelle Jan 15, 2026
9d3aace
move first step save to before training step
jquesnelle Jan 15, 2026
30014ef
remove debug prints
jquesnelle Jan 15, 2026
1d47217
fsdp+ep in muon
jquesnelle Jan 15, 2026
f99e14b
remove incorrect batch size print
jquesnelle Jan 15, 2026
3bbe7ac
fix handling expert_bias parameter when converting to/from hf for qwen3
jquesnelle Jan 15, 2026
59e83f0
pass through use_deepep for dsv3
jquesnelle Jan 15, 2026
6ee08eb
DSv3: YaRN now controlled by rope_factor and not conditional on seque…
jquesnelle Jan 15, 2026
65add62
lora update (#32)
dmahan93 Jan 15, 2026
4b6258a
fix convergence issues from recent updates (#38)
dmahan93 Jan 15, 2026
258773f
fix DeepEP import
jquesnelle Jan 16, 2026
ba9d888
- fix missing peft configs (#41)
dmahan93 Jan 16, 2026
2490a1e
Fix my merge v2 (#42)
dmahan93 Jan 16, 2026
a4fcee7
Change linear check to nn.Linear in _apply method in rowwiseparallel
dmahan93 Jan 16, 2026
e4ad6ce
fix nan issue due to fsdp offloading
xrsrke Jan 16, 2026
5a6e37b
add comment
xrsrke Jan 16, 2026
82cd154
Merge pull request #43 from NousResearch/phuc/fix_cpu_offloading_nan_bug
jquesnelle Jan 19, 2026
b8cb8e6
fix sft preprocess script when previewing kimi
jquesnelle Jan 19, 2026
12fb09b
add --push-to-hub to preprocessing script
jquesnelle Jan 19, 2026
2a642d0
Update README with libnvshmem_host.so troubleshooting
dmahan93 Jan 20, 2026
8e5f859
Merge pull request #44 from NousResearch/deepep-install-readme-updates
jquesnelle Jan 20, 2026
3263b15
Update GRPO.md
dmahan93 Jan 23, 2026
a8ac852
Merge branch 'dev-updated-again' into upstream-2026-24-01
jquesnelle Jan 26, 2026
6d35673
Merge branch 'dev-updated-again' into upstream-2026-24-01
jquesnelle Jan 27, 2026
865ebb8
context parallel support in dsv3 and qwen3
jquesnelle Jan 28, 2026
2ad47cb
fast path for initing bfloat16 params on cpu
jquesnelle Jan 21, 2026
81e54a4
add reference for init scheme
jquesnelle Jan 22, 2026
f04236d
overlapped cpu offload muon
jquesnelle Jan 23, 2026
e7ccfdc
merge fixups
jquesnelle Jan 29, 2026
98f53ee
merge fixups
jquesnelle Jan 30, 2026
668f23e
Add memory tracking and BF16 optimizer state features with Kimi K2 co…
xrsrke Jan 31, 2026
4071454
Add NaN tracker config, FSDP prefetch control, and nvidia-smi memory …
xrsrke Jan 31, 2026
375762b
Add partial resharding support (fsdp_reshard_after_forward accepts int)
xrsrke Jan 31, 2026
0a06429
Add device mesh visualizer for distributed training debugging
xrsrke Jan 31, 2026
fe8d1f0
add option to filter data when preprocessing by a specific string
jquesnelle Feb 1, 2026
f50b804
add kimi_k2_sft
jquesnelle Feb 1, 2026
ed6b753
fix wrong arg used for --push-to-hub
jquesnelle Feb 1, 2026
7f6f3a3
fix attention args, add kimi_k2_ep64_cp1_seq24k_lbs1 160 tps config
xrsrke Feb 4, 2026
c3a14a1
Merge branch 'upstream-2026-24-01' of https://github.com/NousResearch…
xrsrke Feb 4, 2026
f03ff7b
Merge branch 'main' into upstream-2026-10-02
jquesnelle Feb 10, 2026
12ccfd7
use empty_like instead of copying, return to original data type
jquesnelle Feb 10, 2026
9c6a8ac
move lists of tensors to the right device
jquesnelle Feb 10, 2026
37fab69
remove muon parameter name prints
jquesnelle Feb 17, 2026
d84164b
wire up state_dtype through muon
jquesnelle Feb 17, 2026
9370ebb
qwen 3.5 support
jquesnelle Feb 18, 2026
ffdab11
add target tokens to nanoset
jquesnelle Feb 18, 2026
62127c5
allow sdpa for qwen3-next when linear attention off
jquesnelle Feb 18, 2026
a79a842
add GLM 4.7 and 5 configs
jquesnelle Feb 18, 2026
2d03909
add Muon Split
jquesnelle Feb 18, 2026
4a8b980
pass-through dataloader args
jquesnelle Feb 24, 2026
b3d313f
Port LLEP implementation onto fresh upstream-2026-10-02 base
xrsrke Mar 4, 2026
6090b62
Wire LLEP TOML config overrides to MoEArgs in DeepSeekV3
xrsrke Mar 4, 2026
295e7fd
Add 3B benchmark config and missing multinode LLEP TOML configs
xrsrke Mar 4, 2026
d3ba024
Remove redundant LLEP scripts, docs, configs, and model flavors
xrsrke Mar 4, 2026
5219c3d
Restore upstream scripts/loss_compare.py (not LLEP-specific)
xrsrke Mar 4, 2026
2acc5d4
Fix LLEP autotune: get EP group from DTensor mesh, use use_llep flag
xrsrke Mar 4, 2026
f799a1d
comment out optimizer step log
jquesnelle Mar 4, 2026
b8e262e
Add adaptive DeepEP+LLEP switching for expert parallelism
xrsrke Mar 5, 2026
31cf28b
Update LLEP benchmark config to 9.5B model and refresh README data
xrsrke Mar 5, 2026
5d42125
Remove DeepEPLLEPMoE, reuse DeepEPMoE for both deepep and deepep_llep
xrsrke Mar 5, 2026
8a7ea9c
Add DeepEP+LLEP benchmark results to LLEP README
xrsrke Mar 5, 2026
46b0cf1
qwen 3.5 fixes
jquesnelle Mar 6, 2026
7fda10d
Remove unused fused_silu_gate kernel and its tests
xrsrke Mar 6, 2026
7c63d16
Reorganize LLEP files into llep/ package
xrsrke Mar 6, 2026
7a096af
Remove duplicate fast_init_* functions, use upstream trunc_normal_/no…
xrsrke Mar 6, 2026
198ff4f
Merge pull request #55 from NousResearch/phuc/llep_optimized_v2
xrsrke Mar 6, 2026
0dc39ee
Fix failed tests on upstream-2026-10-02
xrsrke Mar 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update LLEP benchmark config to 9.5B model and refresh README data
Replace the 3B debugmodel_ep8_llep_3b (too small for LLEP to help)
with the 9.5B debugmodel_ep8_llep (dim=2048, moe_inter_dim=1536,
16 layers, lbs=8) that stresses GPU memory and shows LLEP's benefit.

Benchmark results on 8xB200 (steps 5-20):
- Speed: +10.9% mean TPS (26,370 vs 23,780)
- Memory: 7 GiB spread vs 42 GiB, max 82% vs 97%
- Without LLEP at lbs=10: OOM. With LLEP: runs fine.
- Loss correctness: <0.001 diff by step 130
  • Loading branch information
xrsrke committed Mar 5, 2026
commit 31cf28b189c385d4c19e6957b0ea7d7048b4ccfb
154 changes: 72 additions & 82 deletions docs/llep.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,19 @@ A ready-made config for inspecting LLEP distribution on 8 GPUs:
```bash
torchrun --nproc_per_node=8 -m torchtitan.train \
--job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 3
```

Or with verbose per-step distribution logging:

```bash
LLEP_DEBUG=1 torchrun --nproc_per_node=8 -m torchtitan.train \
--job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 3 \
2>&1 | tee /tmp/llep_distribution_logs.txt
```

This config uses `debugmodel_ep8_llep` (64 experts, top_k=8, EP=8) with `min_tokens_per_gemm=1` and `adaptive_threshold=0.0` so LLEP always triggers, even at small debug scale. See `train_configs/debug_model_ep8_llep.toml`.
This config uses `debugmodel_ep8_llep` (64 experts, top_k=8, EP=8) with `min_tokens_per_gemm=1` and `adaptive_threshold=0.0` so LLEP always triggers. See `train_configs/debug_model_ep8_llep.toml`.

## Benchmark: LLEP vs Standard EP

Expand All @@ -156,9 +165,7 @@ The `debugmodel_ep8_llep` flavor is a 9.5B-parameter MoE model designed for sing

| Parameter | Value |
|-----------|-------|
| Total params | 9.5B (8.9 GB bf16) |
| MoE expert params | 9.1B (96%) |
| Active params/token | 1.6B (top_k=8 of 64 experts) |
| Total params | 9.5B |
| dim | 2048 |
| inter_dim | 8192 |
| moe_inter_dim | 1536 |
Expand All @@ -167,7 +174,7 @@ The `debugmodel_ep8_llep` flavor is a 9.5B-parameter MoE model designed for sing
| top_k | 8 |
| EP | 8 (8 local experts/GPU) |

Training config: `lbs=6, seq_len=4096, AdamW, no compile, no activation checkpointing`.
Training config: `lbs=8, seq_len=4096, AdamW, no compile, no activation checkpointing`.

### Reproducing

Expand All @@ -177,85 +184,74 @@ cd torchtitan
# WITH LLEP (20 steps)
torchrun --nproc_per_node=8 --rdzv-endpoint=localhost:29500 \
-m torchtitan.train \
--job.config-file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 20 --compile.no-enable \
--job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 20 \
2>&1 | tee llep_with_llep.txt

# WITHOUT LLEP (20 steps, same model)
torchrun --nproc_per_node=8 --rdzv-endpoint=localhost:29500 \
-m torchtitan.train \
--job.config-file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 20 --compile.no-enable --llep.enabled=False \
--job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 20 --llep.enabled=False \
2>&1 | tee llep_no_llep.txt
```

To enable verbose per-step distribution logging (shows BEFORE/AFTER imbalance, send matrix, weight transfers):

```bash
torchrun --nproc_per_node=8 --rdzv-endpoint=localhost:29500 \
-m torchtitan.train \
--job.config-file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 3 --compile.no-enable --llep.verbose=True \
2>&1 | tee llep_verbose_logs.txt
```

### Results (8xB200, 20 steps)

**Speed** (steps 5-20 average, excluding warmup):

| | With LLEP | Without LLEP | Delta |
|---|---|---|---|
| Mean TPS | ~16,270 | ~15,120 | **+7.6%** |
| Mean MFU | 8.2% | 7.6% | +7.9% |
| Mean TPS | ~26,370 | ~23,780 | **+10.9%** |
| Mean MFU | 11.4% | 10.3% | +10.7% |

**Memory** (per-GPU at step 20):

| | With LLEP | Without LLEP |
|---|---|---|
| Active range | 105-107 GiB (59-60%) | 93-124 GiB (52-**69%**) |
| Reserved range | 116-120 GiB (65-67%) | 143-173 GiB (80-**97%**) |
| Max reserved | 120 GiB | **173 GiB** (near OOM) |
| Spread (reserved) | ~4 GiB | **30 GiB** |
| Memory range | 140-147 GiB (78-82%) | 132-173 GiB (74-**97%**) |
| Max memory | 147 GiB | **173 GiB** (near OOM) |
| Spread | ~7 GiB | **42 GiB** |

Without LLEP, the most-loaded GPU hits 97% reserved memory (near OOM) while the least-loaded sits at 80%. LLEP keeps all GPUs in a tight 65-67% band. LLEP is both faster (less straggler waiting from load imbalance) and safer (no GPU near OOM).
Without LLEP, the most-loaded GPU hits 97% memory (near OOM on 178 GiB B200) while the least-loaded sits at 74%. LLEP keeps all GPUs in a tight 78-82% band. LLEP is both faster (less straggler waiting from load imbalance) and safer (no GPU near OOM).

### Per-GPU Memory Breakdown (step 5)

Detailed per-GPU view showing the memory imbalance that LLEP eliminates:

**With LLEP** — all GPUs balanced within a 3 GiB band:

| GPU | Active (GiB) | Active % | Reserved (GiB) | Reserved % | TPS |
|-----|-------------|----------|----------------|------------|-----|
| 0 | 104.29 | 58.5% | 115.74 | 64.9% | 16,088 |
| 1 | 104.40 | 58.5% | 120.20 | 67.4% | 16,078 |
| 2 | 104.98 | 58.9% | 116.29 | 65.2% | 16,087 |
| 3 | 105.24 | 59.0% | 115.59 | 64.8% | 15,978 |
| 4 | 105.48 | 59.1% | 116.83 | 65.5% | 16,082 |
| 5 | 105.64 | 59.2% | 118.05 | 66.2% | 15,975 |
| 6 | 107.10 | 60.1% | 116.87 | 65.5% | 16,082 |
| 7 | 107.45 | 60.2% | 118.10 | 66.2% | 15,955 |
| **Spread** | **3.2** | | **4.6** | | |

**Without LLEP** — wildly imbalanced, one GPU near OOM:

| GPU | Active (GiB) | Active % | Reserved (GiB) | Reserved % | TPS |
|-----|-------------|----------|----------------|------------|-----|
| 0 | 100.20 | 56.2% | 131.74 | 73.9% | 15,636 |
| 1 | 104.78 | 58.7% | 148.46 | 83.2% | 15,655 |
| 2 | 110.80 | 62.1% | 147.52 | 82.7% | 15,655 |
| 3 | 118.87 | **66.6%** | 165.93 | **93.0%** | 15,640 |
| 4 | 123.88 | **69.5%** | 137.88 | 77.3% | 15,633 |
| 5 | 91.15 | **51.1%** | 132.84 | 74.5% | 15,616 |
| 6 | 93.18 | 52.2% | 152.99 | 85.8% | 15,600 |
| 7 | 99.20 | 55.6% | 149.78 | 84.0% | 15,583 |
| **Spread** | **32.7** | | **33.2** | | |
**With LLEP** — all GPUs balanced within a 7 GiB band:

| GPU | Memory (GiB) | Memory % | TPS |
|-----|-------------|----------|-----|
| 0 | 144.62 | 81.1% | 25,400 |
| 1 | 144.37 | 80.9% | 25,464 |
| 2 | 143.46 | 80.4% | 25,491 |
| 3 | 144.57 | 81.1% | 25,398 |
| 4 | 143.00 | 80.2% | 25,458 |
| 5 | 141.26 | 79.2% | 25,349 |
| 6 | 139.89 | 78.4% | 25,510 |
| 7 | 146.60 | 82.2% | 25,477 |
| **Spread** | **6.7** | | |

**Without LLEP** — wildly imbalanced, two GPUs near OOM:

| GPU | Memory (GiB) | Memory % | TPS |
|-----|-------------|----------|-----|
| 0 | 159.46 | 89.4% | 18,936 |
| 1 | 144.33 | 80.9% | 18,908 |
| 2 | 172.27 | **96.6%** | 18,945 |
| 3 | 172.73 | **96.8%** | 18,955 |
| 4 | 166.89 | 93.6% | 18,953 |
| 5 | 165.37 | 92.7% | 18,907 |
| 6 | 155.88 | 87.4% | 18,949 |
| 7 | 145.89 | 81.8% | 18,878 |
| **Spread** | **28.4** | | |

Key observations:
- Without LLEP, GPU 3 reserves **165.9 GiB (93.0%)** of 178 GiB — one more imbalanced step away from OOM.
- GPU 5 is nearly idle at 51.1% active while GPU 4 is at 69.5% — an **18.4 percentage point** gap.
- LLEP compresses the active memory spread from **32.7 GiB to 3.2 GiB** (10x reduction).
- With LLEP every GPU runs at ~16,000+ TPS vs ~15,600 without — the straggler GPU drags everyone down.
- Without LLEP, GPU 2-3 hit **96.6-96.8%** of 178 GiB — one more imbalanced step away from OOM. At `lbs=10` this config OOMs without LLEP.
- GPU 7 is at 81.8% while GPU 3 is at 96.8% — a **15 percentage point** gap.
- LLEP compresses the memory spread from **28.4 GiB to 6.7 GiB** (4x reduction).
- With LLEP every GPU runs at ~25,400+ TPS vs ~18,900 without — **+34% faster** at step 5 when imbalance is worst.

To reproduce this comparison:

Expand All @@ -265,66 +261,60 @@ cd torchtitan
# 5-step memory comparison with LLEP
torchrun --nproc_per_node=8 --rdzv-endpoint=localhost:29500 \
-m torchtitan.train \
--job.config-file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 5 --compile.no-enable \
--job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 5 \
2>&1 | tee llep_memory_with_llep.txt

# 5-step memory comparison without LLEP
torchrun --nproc_per_node=8 --rdzv-endpoint=localhost:29500 \
-m torchtitan.train \
--job.config-file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 5 --compile.no-enable --llep.enabled=False \
--job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 5 --llep.enabled=False \
2>&1 | tee llep_memory_no_llep.txt

# Extract per-GPU memory at step 5
grep "step: 5" llep_memory_with_llep.txt
grep "step: 5" llep_memory_no_llep.txt
grep "step: 5" llep_memory_with_llep.txt
grep "step: 5" llep_memory_no_llep.txt
```

### Loss Correctness (LLEP vs Standard EP)

LLEP produces identical training loss to standard EP, confirming numerical correctness. Both runs use the same seed, weights, and data — only the dispatch/combine path differs.

**WandB**: [nous_research/llep_loss_comparison](https://wandb.ai/nous_research/llep_loss_comparison) — overlay both runs to see matching loss curves.
LLEP produces matching training loss to standard EP, confirming numerical correctness. Both runs use the same seed, weights, and data — only the dispatch/combine path differs.

| Step | With LLEP | Without LLEP | Diff |
|------|-----------|-------------|------|
| 10 | 4.4471 | 4.4062 | 0.041 |
| 30 | 3.1145 | 3.1137 | 0.001 |
| 50 | 2.9153 | 2.8987 | 0.017 |
| 80 | 2.7933 | 2.7947 | 0.001 |
| 100 | 2.7529 | 2.7575 | 0.005 |
| 130 | 2.7769 | 2.7873 | 0.010 |
| 10 | 4.2333 | 4.3096 | 0.076 |
| 30 | 3.0474 | 3.0273 | 0.020 |
| 50 | 2.8995 | 2.8876 | 0.012 |
| 80 | 2.7699 | 2.7721 | 0.002 |
| 100 | 2.7737 | 2.7767 | 0.003 |
| 130 | 2.7993 | 2.7989 | 0.000 |

To reproduce (8 GPUs, ~3 min each, logs to wandb):
To reproduce (8 GPUs, ~10 min each):

```bash
cd torchtitan

# WITH LLEP (130 steps, seed=42, wandb)
WANDB_PROJECT=llep_loss_comparison \
# WITH LLEP (130 steps, seed=42)
torchrun --nproc_per_node=8 --rdzv-endpoint=localhost:29500 \
-m torchtitan.train \
--job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 130 --debug.seed 42 --compile.no-enable \
--llep.enabled True --metrics.log_freq 1 --metrics.enable_wandb
--training.steps 130 --debug.seed=42 --metrics.log_freq=1

# WITHOUT LLEP (130 steps, same seed/config)
WANDB_PROJECT=llep_loss_comparison \
torchrun --nproc_per_node=8 --rdzv-endpoint=localhost:29501 \
-m torchtitan.train \
--job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
--training.steps 130 --debug.seed 42 --compile.no-enable \
--llep.enabled False --metrics.log_freq 1 --metrics.enable_wandb
--training.steps 130 --debug.seed=42 --metrics.log_freq=1 --llep.enabled=False
```

### Unit Tests

```bash
# LPT planning + SwiGLU FFN (5 tests, no GPU required)
# LPT planning + SwiGLU FFN (3 tests, no GPU required)
python -m pytest tests/unit_tests/test_llep.py -v

# Grouped MM, Triton kernels, numerical correctness (17 tests, 1 GPU)
# Triton kernels, numerical correctness (5 tests, 1 GPU)
python -m pytest tests/unit_tests/test_llep_correctness.py -v

# Hook-based flow (59 tests, requires >= 2 GPUs)
Expand Down
13 changes: 7 additions & 6 deletions torchtitan/models/deepseek_v3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@
v_head_dim=128,
mscale=0.70,
),
# ~1.75B model with 64 experts EP=8 for LLEP benchmarking
"debugmodel_ep8_llep_3b": DeepSeekV3ModelArgs(
# ~9.5B model with 64 experts EP=8 for LLEP benchmarking
# dim=2048, moe_inter_dim=1536, 16 layers (1 dense + 15 MoE)
"debugmodel_ep8_llep": DeepSeekV3ModelArgs(
vocab_size=2048,
dim=1024,
inter_dim=4096,
moe_inter_dim=768,
n_layers=12,
dim=2048,
inter_dim=8192,
moe_inter_dim=1536,
n_layers=16,
n_dense_layers=1,
n_heads=16,
moe_args=MoEArgs(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
# DeepSeek-V3 debug model: ~3B params, 64 experts, EP=8, top_k=8, LLEP
# 12 layers (1 dense + 11 MoE), dim=1024, moe_inter_dim=768
# Sized to fit on 8xB200 with lbs=6, seq_len=4096 on new upstream
# DeepSeek-V3 debug model: ~9.5B params, 64 experts, EP=8, top_k=8, LLEP
# 16 layers (1 dense + 15 MoE), dim=2048, moe_inter_dim=1536
# Sized to stress 8xB200 memory with lbs=6, seq_len=4096
#
# Run with LLEP:
# torchrun --nproc_per_node=8 -m torchtitan.train \
# --job.config-file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep_3b.toml \
# --training.steps 20 --compile.no-enable
# --job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
# --training.steps 20
#
# Run without LLEP (baseline):
# torchrun --nproc_per_node=8 -m torchtitan.train \
# --job.config-file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep_3b.toml \
# --training.steps 20 --compile.no-enable --llep.enabled=False
# --job.config_file torchtitan/models/deepseek_v3/train_configs/debug_model_ep8_llep.toml \
# --training.steps 20 --llep.enabled=False

[job]
dump_folder = "./outputs"
description = "DeepSeek-V3 3B debug EP=8 LLEP benchmark"
description = "DeepSeek-V3 9.5B debug EP=8 LLEP benchmark"
print_config = false

[profiling]
Expand All @@ -33,7 +33,7 @@ enable_wandb = false

[model]
name = "deepseek_v3"
flavor = "debugmodel_ep8_llep_3b"
flavor = "debugmodel_ep8_llep"
# test tokenizer, for debug purpose only
hf_assets_path = "./tests/assets/tokenizer"

Expand All @@ -49,7 +49,7 @@ decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 6
local_batch_size = 8
seq_len = 4096
max_norm = 1.0
steps = 5
Expand Down