Merged
tests/unittest/_torch/modules/test_moe_routing.py: 30 changes (19 additions, 11 deletions)
@@ -17,16 +17,19 @@ def test_default_moe_routing(top_k):
     logits = torch.tensor(
         [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [0.1, 0.4, 0.2, 0.3]],
-        dtype=torch.float32)
+        dtype=torch.float32).cuda()
     indices, scales = routing.apply(logits)
+    indices = indices.cpu()
+    scales = scales.cpu()

     assert indices.shape == (3, top_k)
     assert scales.shape == (3, top_k)

     assert indices.dtype == torch.int32
     assert scales.dtype == torch.float32
     reference_indices = torch.tensor([[3, 2, 1], [0, 1, 2], [1, 3, 2]],
                                      dtype=torch.int32)
-    reference_scales = F.softmax(logits, dim=1)
+    reference_scales = F.softmax(logits, dim=1).cpu()

     # Check that the selected experts are the largest top_k values
     for i in range(top_k):
@@ -43,15 +46,14 @@ def test_default_moe_routing(top_k):
                            reference_scales[2, reference_indices[2, i]])


-@pytest.mark.skip(reason="https://nvbugs/5332927")
 @pytest.mark.parametrize("top_k", [1, 2, 3])
 def test_renormalize_moe_routing(top_k):
     routing = RenormalizeMoeRoutingMethod(top_k=top_k)
     assert routing.experts_per_token == top_k

     logits = torch.tensor(
         [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [0.1, 0.4, 0.2, 0.3]],
-        dtype=torch.float32)
+        dtype=torch.float32).cuda()
     indices, scales = routing.apply(logits)
     assert indices.shape == (3, top_k)
     assert scales.shape == (3, top_k)
@@ -78,7 +80,7 @@ def gen_unique_logits(num_tokens, num_experts, dtype):
     return unique_logits.cuda()


-@pytest.mark.parametrize("num_tokens", [1, 30, 2000])
+@pytest.mark.parametrize("num_tokens", [30])
 @pytest.mark.parametrize("top_k", [2, 8])
 @pytest.mark.parametrize("dtype",
                          [torch.bfloat16, torch.float32, torch.float16])
@@ -110,8 +112,10 @@ def test_sparse_mixer_reference():
                            [2.0, 0.0, -float('inf'), -float('inf')],
                            [0.0, 2.0, -float('inf'), -float('inf')],
                            [1.0, 1.0, 1.0, -float('inf')]],
-                          dtype=torch.float32)
+                          dtype=torch.float32).cuda()
     indices, scales = routing.apply(logits.clone())
+    indices = indices.cpu()
+    scales = scales.cpu()

     assert indices.shape == (4, routing.experts_per_token)
     assert scales.shape == (4, routing.experts_per_token)
@@ -147,7 +151,7 @@ def test_load_balanced_moe_routing():
     assert routing.experts_per_token == k

     # Values don't matter for load balanced routing
-    logits = torch.empty((tokens, 4), dtype=torch.float32)
+    logits = torch.empty((tokens, 4), dtype=torch.float32).cuda()

     indices, scales = routing.apply(logits)
     assert indices.shape == (tokens, k)
@@ -164,12 +168,14 @@ def test_load_balanced_moe_routing():

 def test_static_moe_routing():
     routing = StaticMoeRoutingMethod(
-        torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32))
+        torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda())
     assert routing.experts_per_token == 4

     logits = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]],
-                          dtype=torch.float32)
+                          dtype=torch.float32).cuda()
     indices, scales = routing.apply(logits)
+    indices = indices.cpu()
+
     assert scales is None
     assert indices.shape == (2, 4)
     assert indices.dtype == torch.int32
@@ -178,10 +184,12 @@ def test_static_moe_routing():
         indices, torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32))

     routing = StaticMoeRoutingMethod(
-        torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32),
+        torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda(),
         torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]],
-                     dtype=torch.float32))
+                     dtype=torch.float32).cuda())
     indices, scales = routing.apply(logits)
+    scales = scales.cpu()
+
     assert scales is not None
     assert scales.shape == (2, 4)
     assert scales.dtype == torch.float32
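The pattern this diff applies is the same in every test: build the logits on the GPU with .cuda(), run routing.apply there, then copy indices and scales back with .cpu() before comparing against a CPU-side reference. Below is a minimal self-contained sketch of the reference check that test_default_moe_routing performs, using plain PyTorch only; default_routing_reference is a hypothetical helper name for illustration, not part of the test file or the library.

    import torch
    import torch.nn.functional as F

    def default_routing_reference(logits: torch.Tensor, top_k: int):
        # Hypothetical helper mirroring what test_default_moe_routing asserts:
        # scales are softmax probabilities over all experts, and indices pick
        # the top_k largest logits (softmax is monotonic, so top-k of the
        # probabilities equals top-k of the logits).
        scales = F.softmax(logits, dim=-1)
        top_scales, indices = torch.topk(scales, top_k, dim=-1)
        return indices.to(torch.int32), top_scales

    # Device round trip as the updated tests do it: route on the GPU,
    # copy back, and compare against the CPU reference.
    logits = torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float32)
    ref_indices, ref_scales = default_routing_reference(logits, top_k=2)
    if torch.cuda.is_available():
        gpu_indices, gpu_scales = default_routing_reference(logits.cuda(), 2)
        assert torch.equal(gpu_indices.cpu(), ref_indices)
        torch.testing.assert_close(gpu_scales.cpu(), ref_scales)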