
Permute result differs from the torch version #1323

@hanlinxuy

```python
import torch
from permute_unpermute import cuda_token_permute, cuda_token_unpermute, cuda_token_permute_torch, cuda_token_unpermute_torch

torch.manual_seed(1)
device = torch.device("cuda")
hidden_states = torch.randn(16, 64, device=device).to(torch.bfloat16)
router_logits = torch.randn(16, 32, device=device).to(torch.bfloat16)
top_k = 4

router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1)  # (seq_len, top_k)

hidden_states1, row_id_map1 = cuda_token_permute_torch(hidden_states, router_indices)
hidden_states2, row_id_map2 = cuda_token_permute(hidden_states, router_indices)

print(row_id_map1)
print(row_id_map2)
```
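For context, here is a plain-Python sketch of the index bookkeeping a top-k token permute typically performs. It is not the code behind `cuda_token_permute` or `cuda_token_permute_torch`. It assumes that (1) permuted rows are grouped by ascending expert id, (2) copies routed to the same expert are ordered by a stable sort over (token, k), and (3) `row_id_map[token][k]` holds the permuted row of that copy. The helper names are hypothetical. The point is that the same routing can print as different flat vectors depending only on the layout chosen for the map:

```python
# Hypothetical plain-Python sketch of MoE token permutation bookkeeping.
# Assumption: rows are grouped by ascending expert id, ties broken by a
# stable sort over token-major (token, k) order. Real kernels may differ.

def permute_reference(router_indices):
    """router_indices: list of num_tokens lists, each holding top_k expert ids.

    Returns (sorted_slots, row_id_map):
      sorted_slots[r]       = (token, k) copy placed in permuted row r
      row_id_map[token][k]  = permuted row holding that (token, k) copy
    """
    num_tokens = len(router_indices)
    top_k = len(router_indices[0]) if num_tokens else 0
    # Enumerate every (token, k) copy in token-major order.
    slots = [(t, k) for t in range(num_tokens) for k in range(top_k)]
    # Python's sort is stable, so copies with the same expert keep token-major order.
    sorted_slots = sorted(slots, key=lambda tk: router_indices[tk[0]][tk[1]])
    row_id_map = [[0] * top_k for _ in range(num_tokens)]
    for row, (t, k) in enumerate(sorted_slots):
        row_id_map[t][k] = row
    return sorted_slots, row_id_map


def flatten_token_major(row_id_map):
    """Flat index = token * top_k + k."""
    return [row for per_token in row_id_map for row in per_token]


def flatten_slot_major(row_id_map):
    """Flat index = k * num_tokens + token."""
    num_tokens = len(row_id_map)
    top_k = len(row_id_map[0]) if num_tokens else 0
    return [row_id_map[t][k] for k in range(top_k) for t in range(num_tokens)]
```

With `router_indices = [[2, 0], [0, 1], [1, 2]]`, the token-major view is `[4, 0, 1, 2, 3, 5]` while the slot-major view is `[4, 1, 3, 0, 2, 5]`, although both encode the same permutation. An implementation could also return the inverse mapping (the source copy of each permuted row) instead, so two printed maps are only comparable after confirming they use the same convention. To try the sketch on the script's data, pass `router_indices.tolist()`.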

The script above is simple, but the two calls return different row_id_map results. Could someone help me understand why this happens?
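If the two maps use the same layout and still disagree, one possibility (an assumption, not verified against either kernel) is that they order copies routed to the same expert differently, for example one uses a stable sort and the other does not. That changes `row_id_map` but not which rows each expert receives. The hypothetical check below, in plain Python, accepts a map only if it is a permutation of `range(num_tokens * top_k)` and places every (token, k) copy inside its expert's contiguous row segment. It again assumes rows are grouped by ascending expert id and that the map is nested as `row_id_map[token][k]`:

```python
from collections import Counter

# Hypothetical diagnostic: do two row_id_maps differ only in within-expert order?
# Assumption: permuted rows are grouped by ascending expert id, and the map is
# nested as row_id_map[token][k] -> permuted row index.

def expert_segments(router_indices, num_experts):
    """Return the [start, end) permuted-row range owned by each expert."""
    counts = Counter(e for per_token in router_indices for e in per_token)
    segments, start = [], 0
    for e in range(num_experts):
        segments.append((start, start + counts[e]))
        start += counts[e]
    return segments


def is_valid_grouping(router_indices, row_id_map, num_experts):
    """True if row_id_map is a permutation that keeps each copy in its expert's segment."""
    if len(row_id_map) != len(router_indices):
        return False
    segments = expert_segments(router_indices, num_experts)
    total = sum(len(per_token) for per_token in router_indices)
    rows = [row for per_token in row_id_map for row in per_token]
    # Every permuted row must be used exactly once.
    if sorted(rows) != list(range(total)):
        return False
    for t, per_token in enumerate(router_indices):
        for k, e in enumerate(per_token):
            if not 0 <= e < num_experts:
                return False
            lo, hi = segments[e]
            if not lo <= row_id_map[t][k] < hi:
                return False
    return True
```

For the script, `num_experts` would be 32 (the width of `router_logits`). If both maps pass after conversion to the same nested layout, they differ only in within-expert ordering; as long as permute and unpermute use the same map consistently, that ordering should not change the unpermuted result, which can be confirmed by comparing the unpermuted outputs directly.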

The code was run on an NVIDIA L40 GPU with CUDA 12.8 (cu128).

The environment is below (output of `uv pip list`):

```
Package                  Version   Editable project location
------------------------ --------- ------------------------------------------------------------------------
absl-py                  2.3.1
filelock                 3.20.0
fsspec                   2025.10.0
grouped-gemm             1.1.4
jinja2                   3.1.6
markupsafe               3.0.3
mpmath                   1.3.0
networkx                 3.6
numpy                    2.3.5
nvidia-cublas-cu12       12.8.4.1
nvidia-cuda-cupti-cu12   12.8.90
nvidia-cuda-nvrtc-cu12   12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12        9.10.2.21
nvidia-cufft-cu12        11.3.3.83
nvidia-cufile-cu12       1.13.1.3
nvidia-curand-cu12       10.3.9.90
nvidia-cusolver-cu12     11.7.3.90
nvidia-cusparse-cu12     12.5.8.93
nvidia-cusparselt-cu12   0.7.1
nvidia-nccl-cu12         2.27.5
nvidia-nvjitlink-cu12    12.8.93
nvidia-nvshmem-cu12      3.3.20
nvidia-nvtx-cu12         12.8.90
setuptools               80.9.0
sympy                    1.14.0
torch                    2.9.1
triton                   3.5.1
typing-extensions        4.15.0
```
