System Info
- transformers version: 5.3.0
- Platform: Linux
- Python version: 3.13.5
- PyTorch version: 2.8.0+cu128
Who can help?
@Rocketknight1
Information
Reproduction
GlmMoeDsa models crash on the second forward pass. The DSA indexer's _cached_keys and _cached_indices persist between calls and cause shape mismatches or out-of-bounds scatter indices.
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("yujiepan/glm-5-tiny-random", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("yujiepan/glm-5-tiny-random")
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# First forward: OK
out1 = model(**inputs)
print(out1.logits.shape) # torch.Size([1, 1, 154880])
# Second forward: CRASH
out2 = model(**inputs) # AcceleratorError: CUDA error: device-side assert triggered
Same issue with yujiepan/glm-moe-dsa-tiny-random.
Error
With CUDA_LAUNCH_BLOCKING=1:
File ".../transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py", line 414, in forward
index_mask.scatter_(-1, topk_indices, 0.0) # [B, S, T]
torch.AcceleratorError: CUDA error: device-side assert triggered
The underlying issue is at modeling_glm_moe_dsa.py:198:
k_cached = torch.cat([self._cached_keys, k], dim=1) # [B, T, D]
On the second forward call, self._cached_keys still holds stale state from the first call, leading to shape mismatches or invalid indices.
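The mechanism can be illustrated with a self-contained toy (a simplified stand-in, not the real modeling code; the scores here are a monotone placeholder so the newest key deterministically wins topk). Because the cached keys grow the effective length T on every call while the scatter target is sized from the fresh input length S, the second call produces indices outside the mask:

```python
import torch

class ToyIndexer:
    # Simplified stand-in for the DSA indexer's caching behavior described
    # in this report; NOT the real GlmMoeDsa code. Scores are a monotone
    # placeholder so the newest (highest-index) key always wins topk.
    def __init__(self, top_k=1):
        self.top_k = top_k
        self._cached_keys = None  # persists across calls

    def __call__(self, k):  # k: [B, S, D]
        B, S, _ = k.shape
        if self._cached_keys is not None:
            # analog of the cat at modeling_glm_moe_dsa.py:198 -- T grows every call
            k = torch.cat([self._cached_keys, k], dim=1)
        self._cached_keys = k
        T = k.shape[1]
        # placeholder attention scores over all T cached keys
        scores = torch.arange(T, dtype=torch.float32).expand(B, S, T)
        topk_indices = scores.topk(self.top_k, dim=-1).indices  # values in [0, T)
        # mask is sized from the fresh input length S, not the cached length T
        index_mask = torch.ones(B, S, S)
        index_mask.scatter_(-1, topk_indices, 0.0)  # out of bounds once T > S
        return index_mask

indexer = ToyIndexer()
k = torch.randn(1, 4, 8)
print(indexer(k).shape)  # first call: T == S == 4, scatter is in range
try:
    indexer(k)  # second call: T == 8, topk index 7 >= mask size 4
except RuntimeError as e:
    print("second call failed:", e)
```

On CPU this surfaces as a RuntimeError from scatter_; on CUDA the same out-of-range index shows up as the device-side assert above.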
Expected behavior
The model should be callable multiple times without error. The DSA indexer should either reset its cache between forward passes or not use persistent state for inference without KV cache.
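Until the indexer is fixed, one possible stopgap is to clear the stale state between independent forward passes. This is a sketch, not a transformers API: reset_dsa_indexer_cache is a hypothetical helper, and the attribute names are assumed from the traceback in this report.

```python
import torch
from torch import nn

def reset_dsa_indexer_cache(model: nn.Module) -> None:
    # Hypothetical workaround helper (not part of transformers): walk all
    # submodules and clear any persistent indexer buffers so the next
    # forward pass starts from a clean state. The attribute names
    # (_cached_keys, _cached_indices) are taken from this report and may
    # differ across versions.
    for module in model.modules():
        for attr in ("_cached_keys", "_cached_indices"):
            if getattr(module, attr, None) is not None:
                setattr(module, attr, None)
```

With the reproduction above, calling reset_dsa_indexer_cache(model) between out1 and out2 should avoid the stale-cache crash, at the cost of discarding any state the indexer intended to keep.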
Additional context
This is related to other known GlmMoeDsa indexer issues (#44360, #44263). The stale cache issue compounds with those bugs — even if the indexer logic is fixed, the persistent cache between calls will continue to cause problems.