Description
import torch
from permute_unpermute import cuda_token_permute, cuda_token_unpermute, cuda_token_permute_torch, cuda_token_unpermute_torch
torch.manual_seed(1)
device = torch.device("cuda")
hidden_states = torch.randn(16, 64, device=device).to(torch.bfloat16)
router_logits = torch.randn(16, 32, device=device).to(torch.bfloat16)
top_k = 4
router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1) # (seq_len, top_k)
hidden_states1, row_id_map1 = cuda_token_permute_torch(hidden_states, router_indices)
hidden_states2, row_id_map2 = cuda_token_permute(hidden_states, router_indices)
print(row_id_map1)
print(row_id_map2)

This is a simple script, but the two calls return different row_id_map results. Could someone help me understand why this happens?
The code was run on an L40 GPU with cu128.
The environment is listed below.
uv pip list
Package Version Editable project location
------------------------ --------- ------------------------------------------------------------------------
absl-py 2.3.1
filelock 3.20.0
fsspec 2025.10.0
grouped-gemm 1.1.4
jinja2 3.1.6
markupsafe 3.0.3
mpmath 1.3.0
networkx 3.6
numpy 2.3.5
nvidia-cublas-cu12 12.8.4.1
nvidia-cuda-cupti-cu12 12.8.90
nvidia-cuda-nvrtc-cu12 12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12 9.10.2.21
nvidia-cufft-cu12 11.3.3.83
nvidia-cufile-cu12 1.13.1.3
nvidia-curand-cu12 10.3.9.90
nvidia-cusolver-cu12 11.7.3.90
nvidia-cusparse-cu12 12.5.8.93
nvidia-cusparselt-cu12 0.7.1
nvidia-nccl-cu12 2.27.5
nvidia-nvjitlink-cu12 12.8.93
nvidia-nvshmem-cu12 3.3.20
nvidia-nvtx-cu12 12.8.90
setuptools 80.9.0
sympy 1.14.0
torch 2.9.1
triton 3.5.1
typing-extensions 4.15.0
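
One thing that may be worth ruling out before treating the mismatch above as a bug: the two functions might encode the same permutation in different conventions (for example, "destination row to source slot" versus "source slot to destination row"). I have not confirmed which convention either `cuda_token_permute` or `cuda_token_permute_torch` uses, so this is only an assumption. The sketch below uses plain PyTorch on CPU, with hypothetical helpers `reference_permute` and `invert_map` that are not part of `permute_unpermute`, to show that two maps can hold different numbers while still describing the same token layout.

```python
import torch


def reference_permute(tokens, expert_indices):
    """Hypothetical reference: group token copies by expert id via a stable sort."""
    top_k = expert_indices.size(1)
    flat = expert_indices.reshape(-1)  # (num_tokens * top_k,)
    # dest_to_src[d] = flattened (token, k) slot that lands in output row d
    dest_to_src = torch.argsort(flat, stable=True)
    permuted = tokens.index_select(0, dest_to_src // top_k)
    return permuted, dest_to_src


def invert_map(dest_to_src):
    """Convert a dest->source map into the equivalent source->dest map."""
    src_to_dest = torch.empty_like(dest_to_src)
    src_to_dest[dest_to_src] = torch.arange(dest_to_src.numel())
    return src_to_dest


torch.manual_seed(1)
hidden_states = torch.randn(16, 64)
router_logits = torch.randn(16, 32)
top_k = 4
_, router_indices = torch.topk(router_logits, top_k, dim=-1)  # (seq_len, top_k)

permuted, dest_to_src = reference_permute(hidden_states, router_indices)
src_to_dest = invert_map(dest_to_src)

# Composing the two maps should give the identity permutation.
round_trip_ok = torch.equal(
    dest_to_src[src_to_dest], torch.arange(dest_to_src.numel())
)

# Rebuild the permuted rows using only the source->dest map.
rebuilt = torch.empty_like(permuted)
rebuilt[src_to_dest] = hidden_states.repeat_interleave(top_k, dim=0)
outputs_match = torch.equal(rebuilt, permuted)

print("maps identical:", torch.equal(dest_to_src, src_to_dest))
print("round trip ok:", round_trip_ok)
print("outputs match:", outputs_match)
```

If the same kind of check is applied to the real outputs, comparing `hidden_states1` with `hidden_states2` (and the results of the matching unpermute calls) is likely more informative than comparing `row_id_map1` with `row_id_map2` directly. If the permuted tensors also differ, the order of rows within one expert's group may still legitimately vary when the underlying sort is not stable.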