Draft
Changes from 1 commit
Commits
150 commits
53028bb
Move padding
JacobHelwig Jan 17, 2026
d3c161d
Prompt lens
JacobHelwig Jan 17, 2026
c82c2fb
Tests
JacobHelwig Jan 15, 2026
59a0c72
Multiple 3D tensors test
JacobHelwig Jan 15, 2026
f93d5a1
Merge tests
JacobHelwig Jan 15, 2026
11b047d
init
JacobHelwig Jan 12, 2026
3d6a996
Init debug script
JacobHelwig Jan 12, 2026
9dedb55
Plan
JacobHelwig Jan 13, 2026
d91c787
Add top-k log probs
JacobHelwig Jan 14, 2026
a69ddd3
Stage-wise top-k
JacobHelwig Jan 15, 2026
d325faa
Distillation cfg
JacobHelwig Jan 16, 2026
0005747
Distillation losses
JacobHelwig Jan 16, 2026
b40577e
Re-factor distillation
JacobHelwig Jan 16, 2026
a39394f
RM unused
JacobHelwig Jan 17, 2026
5fdfdac
Working: Add full JSD/KL, rm FSDP cfg, routing by distillation loss type
JacobHelwig Jan 17, 2026
61710cf
rm example
JacobHelwig Jan 17, 2026
a3655d6
dp distillation cfg
JacobHelwig Jan 17, 2026
1d9d8d7
Fix clamping
JacobHelwig Jan 17, 2026
da3c2fe
Enable ref w distillation
JacobHelwig Jan 18, 2026
1492aff
Teacher model cfg
JacobHelwig Jan 18, 2026
f05b140
Legacy distillation
JacobHelwig Jan 18, 2026
99c8abf
Distillation cfg to actor_rollout_ref
JacobHelwig Jan 18, 2026
0db7f63
Clamping and pass distillation cfg instead of actor
JacobHelwig Jan 18, 2026
6c69e66
Ruff
JacobHelwig Jan 18, 2026
ad0a3e7
Distillation training script
JacobHelwig Jan 18, 2026
7bb9ef3
Update train script
JacobHelwig Jan 18, 2026
f55e542
Decouple distillation and ref configs
JacobHelwig Jan 18, 2026
da183f8
Distillation cfg validation
JacobHelwig Jan 18, 2026
27e682e
Loss settings in cfg
JacobHelwig Jan 18, 2026
15d0d0b
Use estimator
JacobHelwig Jan 18, 2026
4d1e998
Distillation_config->distillation
JacobHelwig Jan 18, 2026
76d7116
Distillation loss signatures
JacobHelwig Jan 18, 2026
03e5ec8
Running w no rm pad
JacobHelwig Jan 18, 2026
6d08421
Ulysses working
JacobHelwig Jan 18, 2026
24cf8ae
SP in script
JacobHelwig Jan 18, 2026
c6dfc28
Log eps
JacobHelwig Jan 18, 2026
05dac28
Add distillation tests
JacobHelwig Jan 18, 2026
4cb8e43
Reduce tests
JacobHelwig Jan 18, 2026
b9c7b86
Doc strings
JacobHelwig Jan 18, 2026
c65bc5e
Ruff
JacobHelwig Jan 18, 2026
442cbee
Not implemented
JacobHelwig Jan 18, 2026
73f818e
Clamp
JacobHelwig Jan 19, 2026
07efd60
Divergence note
JacobHelwig Jan 19, 2026
9598240
take abs and distillation loss inputs
JacobHelwig Jan 19, 2026
3f2be47
Take abs in name
JacobHelwig Jan 19, 2026
f6aad57
loss fix
JacobHelwig Jan 19, 2026
f0d11c8
null
JacobHelwig Jan 19, 2026
be17339
Take abs fix
JacobHelwig Jan 19, 2026
c5c635e
RM blank lines
JacobHelwig Jan 19, 2026
057de32
Long line and generate cfg
JacobHelwig Jan 19, 2026
a656305
rl dataset
JacobHelwig Jan 19, 2026
640ffc5
rm TODO
JacobHelwig Jan 19, 2026
548c0de
CI fixes
JacobHelwig Jan 19, 2026
89eb552
PC
JacobHelwig Jan 19, 2026
bda32f4
None cfg
JacobHelwig Jan 19, 2026
a5ad75d
Distillation
JacobHelwig Jan 19, 2026
0ebb0e0
cfg to test
JacobHelwig Jan 19, 2026
4751108
None config
JacobHelwig Jan 19, 2026
f80c01f
Dist cfg
JacobHelwig Jan 19, 2026
8ed4b68
Distillation enabled
JacobHelwig Jan 19, 2026
8f79ec0
Dist cfg test eng
JacobHelwig Jan 19, 2026
92155a6
Update generated cfg
JacobHelwig Jan 19, 2026
218493c
Entropy bonus
JacobHelwig Jan 19, 2026
130301f
Types
JacobHelwig Jan 19, 2026
aa9be1d
Ruff
JacobHelwig Jan 26, 2026
c5c007d
Generated PPO trainer
JacobHelwig Jan 26, 2026
9cb348b
RM take abs
JacobHelwig Jan 26, 2026
c085f20
Only teacher topk in utils
JacobHelwig Jan 26, 2026
aae5c15
Re-factor stages
JacobHelwig Jan 26, 2026
6544f7c
FSDP losses
JacobHelwig Jan 27, 2026
45a6ddb
Update training config
JacobHelwig Jan 27, 2026
2b9f75d
Ruff
JacobHelwig Jan 27, 2026
ec4e116
FSDP utils
JacobHelwig Jan 27, 2026
6f79f7c
Fix teacher logprobs for top-k
JacobHelwig Jan 27, 2026
ba4874a
Doc string
JacobHelwig Jan 27, 2026
dcb2050
Distillation loss range
JacobHelwig Jan 27, 2026
457460f
Mask bug fix
JacobHelwig Jan 27, 2026
4af7188
Disable clamp
JacobHelwig Jan 27, 2026
d42525b
Clamp distillation loss to 10
JacobHelwig Jan 27, 2026
777219e
Clamp probs
JacobHelwig Jan 27, 2026
e1e97ab
log prob clamping
JacobHelwig Jan 28, 2026
f78e059
Clipping defaults
JacobHelwig Jan 28, 2026
acfd386
Script
JacobHelwig Jan 28, 2026
a08d87b
k3 fixes
JacobHelwig Jan 28, 2026
b4aa550
Loss settings doc string
JacobHelwig Jan 29, 2026
730ddf7
Separate fields for log probs
JacobHelwig Jan 29, 2026
a7423bc
Policy loss default
JacobHelwig Jan 29, 2026
1dbbb1f
Generated config
JacobHelwig Jan 29, 2026
280a9b3
Fix kld
JacobHelwig Jan 30, 2026
ec34bfd
Generate cfg
JacobHelwig Feb 3, 2026
8fc2805
mindspeed
JacobHelwig Feb 3, 2026
b369f38
ms
JacobHelwig Feb 3, 2026
11a7018
Padding util for multi-dimensional tensors
JacobHelwig Feb 4, 2026
d5bd52b
RM slicing
JacobHelwig Feb 4, 2026
5ca2116
Ruff
JacobHelwig Feb 4, 2026
e9bf2d5
Megatron script
JacobHelwig Jan 22, 2026
c81ef6b
Dist cfg
JacobHelwig Jan 22, 2026
9c9e01a
Dist cfg pyclass
JacobHelwig Jan 22, 2026
1bb74fe
megatron transformer impl
JacobHelwig Jan 29, 2026
061bb6d
dev on megatron topk logprobs
JacobHelwig Jan 29, 2026
58cf25b
Init megatron losses
JacobHelwig Jan 29, 2026
69c0216
Update script
JacobHelwig Jan 30, 2026
b742142
Compute top-k megatron
JacobHelwig Jan 30, 2026
809a923
Working top-k megatron loss
JacobHelwig Jan 30, 2026
34e5312
VP SM
JacobHelwig Jan 30, 2026
a09c3ab
VP SM in VP topk KL
JacobHelwig Jan 30, 2026
813796c
Clean up script
JacobHelwig Jan 30, 2026
c0990a3
Comments and mass logging
JacobHelwig Jan 30, 2026
f1a6112
Handle nested logits
JacobHelwig Jan 30, 2026
5823553
Minimize mcore cfg
JacobHelwig Jan 30, 2026
49b7612
Correction to gradients of top-k KL loss
JacobHelwig Jan 30, 2026
6178aaf
Working script
JacobHelwig Jan 30, 2026
9a4d8c8
Working script
JacobHelwig Jan 30, 2026
8931c52
Clamping
JacobHelwig Jan 31, 2026
4ca9fcc
Fix the scatter add
JacobHelwig Jan 31, 2026
e926c03
Update comments
JacobHelwig Jan 31, 2026
52a273c
log softmax for stability
JacobHelwig Jan 31, 2026
b35a715
FSDP parity tests
JacobHelwig Jan 31, 2026
a98c841
Workflow
JacobHelwig Jan 31, 2026
5c50f79
Ruff
JacobHelwig Jan 31, 2026
709d537
TP in exp name
JacobHelwig Feb 1, 2026
6a66a0e
Train script
JacobHelwig Feb 1, 2026
b1c5f68
Fix megatron estimator distillation
JacobHelwig Feb 1, 2026
8be7b34
Generate megatron cfg
JacobHelwig Feb 1, 2026
ae106c8
Mindspeed distillation
JacobHelwig Feb 2, 2026
98bfc42
Ruff
JacobHelwig Feb 2, 2026
2f9cc16
Ruff
JacobHelwig Feb 4, 2026
b2331e2
Init run qwen codeforces
JacobHelwig Feb 1, 2026
8996495
code force analysis
JacobHelwig Feb 1, 2026
024b8e7
Patch runtime module
JacobHelwig Feb 1, 2026
ada3411
console logging
JacobHelwig Feb 1, 2026
e936cf9
Working codeforces script
JacobHelwig Feb 1, 2026
ac4623c
Add logging
JacobHelwig Feb 1, 2026
00499c5
Init codeforces OPD
JacobHelwig Feb 2, 2026
291302a
Update qwen codeforces
JacobHelwig Feb 2, 2026
0948830
Init mopd cfg
JacobHelwig Feb 1, 2026
20497e2
Multi-teacher configs
JacobHelwig Feb 2, 2026
10dcf24
Return teacher cfg by ID
JacobHelwig Feb 2, 2026
c3d4f44
Multi task knowledge acquisition
JacobHelwig Feb 2, 2026
f6b410e
Training running
JacobHelwig Feb 2, 2026
28d41d5
RM debugging device
JacobHelwig Feb 2, 2026
001c6fc
Domain sets
JacobHelwig Feb 3, 2026
ca72c18
Acquire knowledge in ray trainer
JacobHelwig Feb 3, 2026
e828a76
Update script
JacobHelwig Feb 3, 2026
ebed86b
Update script
JacobHelwig Feb 3, 2026
62a901d
Fix ref cfg prep
JacobHelwig Feb 3, 2026
4a234bc
Distillation loss config and refactor distillation loss
JacobHelwig Feb 4, 2026
2164f2f
Fix setting distillation loss settings
JacobHelwig Feb 4, 2026
45b894b
Teacher pool
JacobHelwig Feb 4, 2026
1858382
lora in script
JacobHelwig Feb 4, 2026
Tests
JacobHelwig committed Feb 4, 2026
commit c82c2fba8189abcc50bdd6c3c2aae8e08e6ff661
86 changes: 86 additions & 0 deletions tests/test_protocol_v2_on_cpu.py
@@ -1066,3 +1066,89 @@ def test_contiguous():
data_cont.consolidate()

tu.assert_tensordict_eq(data_cont, data)


def test_concat_nested_tensor_3d():
"""Test concatenating 3D nested tensors."""
# Create 3D nested tensors
# Each batch element has shape [4, N] where N varies
batch1_tensors = [
torch.arange(4).expand(4, 4), # [4, 4]
torch.arange(5).expand(4, 5), # [4, 5]
]
batch2_tensors = [
torch.arange(6).expand(4, 6), # [4, 6]
torch.arange(7).expand(4, 7), # [4, 7]
]

nested_batch1 = torch.nested.as_nested_tensor(batch1_tensors, layout=torch.jagged)
nested_batch2 = torch.nested.as_nested_tensor(batch2_tensors, layout=torch.jagged)

# Concatenate the nested tensors
output = tu.concat_nested_tensors([nested_batch1, nested_batch2])

# Verify the output has 4 batch elements
assert output.shape[0] == 4

# Verify each batch element has the correct shape and values
output_unbind = output.unbind(0)

assert torch.all(torch.eq(output_unbind[0], batch1_tensors[0])).item()
assert torch.all(torch.eq(output_unbind[1], batch1_tensors[1])).item()
assert torch.all(torch.eq(output_unbind[2], batch2_tensors[0])).item()
assert torch.all(torch.eq(output_unbind[3], batch2_tensors[1])).item()

# Verify shapes are preserved
assert output_unbind[0].shape == torch.Size([4, 4])
assert output_unbind[1].shape == torch.Size([4, 5])
assert output_unbind[2].shape == torch.Size([4, 6])
assert output_unbind[3].shape == torch.Size([4, 7])


def test_chunk_concat_roundtrip_3d():
"""Test that chunk and concat are inverse operations for 3D nested tensors."""
# Create a TensorDict with 3D nested tensors
position_ids = torch.nested.as_nested_tensor(
[
torch.arange(4).expand(4, 4),
torch.arange(5).expand(4, 5),
torch.arange(6).expand(4, 6),
torch.arange(7).expand(4, 7),
],
layout=torch.jagged,
)

# Also add regular 2D nested tensor for comparison
input_ids = torch.nested.as_nested_tensor(
[torch.arange(4), torch.arange(5), torch.arange(6), torch.arange(7)], layout=torch.jagged
)

original_td = tu.get_tensordict(
{
"input_ids": input_ids,
"position_ids": position_ids,
},
)

# Chunk the TensorDict
chunks = tu.chunk_tensordict(original_td, chunks=2)
assert len(chunks) == 2
assert len(chunks[0]) == 2
assert len(chunks[1]) == 2

# Concatenate the chunks back together
reconstructed_td = tu.concat_tensordict(chunks)

# Verify the reconstructed TensorDict matches the original
assert len(reconstructed_td) == 4

# Verify input_ids (2D nested tensor)
assert torch.all(torch.eq(reconstructed_td["input_ids"].values(), original_td["input_ids"].values())).item()

# Verify position_ids (3D nested tensor)
assert torch.all(torch.eq(reconstructed_td["position_ids"].values(), original_td["position_ids"].values())).item()

# Verify offsets match
assert torch.all(torch.eq(reconstructed_td["input_ids"].offsets(), original_td["input_ids"].offsets())).item()

assert torch.all(torch.eq(reconstructed_td["position_ids"].offsets(), original_td["position_ids"].offsets())).item()
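For context (not part of the diff): a minimal sketch, assuming a recent PyTorch (2.3+) with jagged-layout nested tensors, of the layout properties these assertions rely on. A jagged nested tensor packs its components into a dense values() buffer plus an offsets() vector marking component boundaries; concatenating two nested batches along the batch dim is then equivalent to rebuilding one nested tensor from the combined component lists, which is the behavior the tests above check for tu.concat_nested_tensors and the chunk/concat roundtrip. The sketch uses the standard ragged-first-dim case for simplicity, whereas the tests above exercise raggedness in the trailing dim.

import torch

# Minimal sketch (assumes PyTorch >= 2.3 with jagged nested tensors).
# The ragged dimension here is the first component dim, the default supported case.
batch1 = [torch.ones(4, 8), torch.ones(5, 8)]
batch2 = [torch.ones(6, 8), torch.ones(7, 8)]

a = torch.nested.as_nested_tensor(batch1, layout=torch.jagged)
b = torch.nested.as_nested_tensor(batch2, layout=torch.jagged)

# values() is the packed dense buffer; offsets() marks where each component starts.
assert a.values().shape == torch.Size([9, 8])  # 4 + 5 rows stacked
assert a.offsets().tolist() == [0, 4, 9]

# Concatenating two nested batches along dim 0 is equivalent to rebuilding one
# nested tensor from the combined component lists.
combined = torch.nested.as_nested_tensor(
    list(a.unbind(0)) + list(b.unbind(0)), layout=torch.jagged
)
assert combined.shape[0] == 4
assert combined.offsets().tolist() == [0, 4, 9, 15, 22]
assert torch.equal(combined.unbind(0)[2], batch2[0])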