diff --git a/tests/unittest/_torch/modules/test_moe_routing.py b/tests/unittest/_torch/modules/test_moe_routing.py index 141529c749e..53a6e0992b4 100644 --- a/tests/unittest/_torch/modules/test_moe_routing.py +++ b/tests/unittest/_torch/modules/test_moe_routing.py @@ -17,8 +17,11 @@ def test_default_moe_routing(top_k): logits = torch.tensor( [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [0.1, 0.4, 0.2, 0.3]], - dtype=torch.float32) + dtype=torch.float32).cuda() indices, scales = routing.apply(logits) + indices = indices.cpu() + scales = scales.cpu() + assert indices.shape == (3, top_k) assert scales.shape == (3, top_k) @@ -26,7 +29,7 @@ def test_default_moe_routing(top_k): assert scales.dtype == torch.float32 reference_indices = torch.tensor([[3, 2, 1], [0, 1, 2], [1, 3, 2]], dtype=torch.int32) - reference_scales = F.softmax(logits, dim=1) + reference_scales = F.softmax(logits, dim=1).cpu() # Check that the selected experts are the largest top_k values for i in range(top_k): @@ -43,7 +46,6 @@ def test_default_moe_routing(top_k): reference_scales[2, reference_indices[2, i]]) -@pytest.mark.skip(reason="https://nvbugs/5332927") @pytest.mark.parametrize("top_k", [1, 2, 3]) def test_renormalize_moe_routing(top_k): routing = RenormalizeMoeRoutingMethod(top_k=top_k) @@ -51,7 +53,7 @@ def test_renormalize_moe_routing(top_k): logits = torch.tensor( [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1], [0.1, 0.4, 0.2, 0.3]], - dtype=torch.float32) + dtype=torch.float32).cuda() indices, scales = routing.apply(logits) assert indices.shape == (3, top_k) assert scales.shape == (3, top_k) @@ -78,7 +80,7 @@ def gen_unique_logits(num_tokens, num_experts, dtype): return unique_logits.cuda() -@pytest.mark.parametrize("num_tokens", [1, 30, 2000]) +@pytest.mark.parametrize("num_tokens", [30]) @pytest.mark.parametrize("top_k", [2, 8]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32, torch.float16]) @@ -110,8 +112,10 @@ def test_sparse_mixer_reference(): [2.0, 0.0, -float('inf'), -float('inf')], [0.0, 2.0, -float('inf'), -float('inf')], [1.0, 1.0, 1.0, -float('inf')]], - dtype=torch.float32) + dtype=torch.float32).cuda() indices, scales = routing.apply(logits.clone()) + indices = indices.cpu() + scales = scales.cpu() assert indices.shape == (4, routing.experts_per_token) assert scales.shape == (4, routing.experts_per_token) @@ -147,7 +151,7 @@ def test_load_balanced_moe_routing(): assert routing.experts_per_token == k # Values don't matter for load balanced routing - logits = torch.empty((tokens, 4), dtype=torch.float32) + logits = torch.empty((tokens, 4), dtype=torch.float32).cuda() indices, scales = routing.apply(logits) assert indices.shape == (tokens, k) @@ -164,12 +168,14 @@ def test_load_balanced_moe_routing(): def test_static_moe_routing(): routing = StaticMoeRoutingMethod( - torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32)) + torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda()) assert routing.experts_per_token == 4 logits = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], - dtype=torch.float32) + dtype=torch.float32).cuda() indices, scales = routing.apply(logits) + indices = indices.cpu() + assert scales is None assert indices.shape == (2, 4) assert indices.dtype == torch.int32 @@ -178,10 +184,12 @@ def test_static_moe_routing(): indices, torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32)) routing = StaticMoeRoutingMethod( - torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32), + torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]], dtype=torch.int32).cuda(), torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], - dtype=torch.float32)) + dtype=torch.float32).cuda()) indices, scales = routing.apply(logits) + scales = scales.cpu() + assert scales is not None assert scales.shape == (2, 4) assert scales.dtype == torch.float32