
Conversation

Contributor
@piDack piDack commented Dec 10, 2025


This PR adds support for the GLM-ASR architecture, specifically validating with the zai-org/GLM-ASR-Nano-2512 model.

Key Changes:

  • Model support: implemented the logic needed to support GLM-ASR models.
  • Conversion script: updated convert_hf_to_gguf.py to handle dynamic configuration keys (GLM-ASR uses "lm_config" instead of "text_config"). It now identifies the config section by checking:
    llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config"

Result

(output screenshot)

@piDack piDack changed the title from "[model] add glm-asr support" to "model: add glm-asr support" on Dec 10, 2025
Comment on lines +1899 to +1903
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
cur = ggml_add(ctx0, cur, model.mm_1_b);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
cur = ggml_add(ctx0, cur, model.mm_2_b);
Collaborator

replace this with build_ffn
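
For reference, a minimal sketch of what this could look like with clip.cpp's build_ffn helper; the argument order and the erf-GELU op constant are assumptions and should be checked against the actual helper:

// hypothetical: fold the two mul_mat/add pairs and the erf GELU into build_ffn
cur = build_ffn(cur,
        model.mm_1_w, model.mm_1_b, // up projection + bias
        nullptr,      nullptr,      // no gate branch
        model.mm_2_w, model.mm_2_b, // down projection + bias
        FFN_GELU_ERF,               // assumed op type matching ggml_gelu_erf above
        il);                        // il: layer index used for debug naming (assumption)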

// whisper downscales input token by half after conv1d
n_patches /= 2;
// reshape by merge_factor
n_patches /= 4;
Collaborator

Suggested change
- n_patches /= 4;
+ n_patches /= n_merge;

and you also need to set hparams.n_merge = 4 upon loading hparams, see load_hparams() function
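
A minimal sketch of the corresponding load_hparams() change, assuming the model gets its own projector-type case; the PROJECTOR_TYPE_* constant is hypothetical:

case PROJECTOR_TYPE_GLMASR: // hypothetical projector-type constant for this model
    {
        hparams.n_merge = 4; // frame merge factor consumed by n_patches /= n_merge
    } break;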

cur = ggml_norm(ctx0, cur, hparams.eps);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
cur = ggml_add(ctx0, cur, model.mm_norm_pre_b);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * 4, cur->ne[1] / 4);
Collaborator
@ngxson ngxson commented Dec 10, 2025

this will fail if the number of elements is not divisible by 4

instead, you should abstract out the StackAudioFrames used by ultravox into a new function, build_stack(), and reuse it here
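
A minimal sketch of such a build_stack() helper, modeled on the ultravox StackAudioFrames block (pad first, then regroup); the function name, signature, and placement are placeholders:

// hypothetical helper: zero-pad the flattened embeddings so the element count is a
// multiple of n_embd * n_stack, then regroup every n_stack consecutive frames per row
static ggml_tensor * build_stack(ggml_context * ctx0, ggml_tensor * cur, int64_t n_stack) {
    const int64_t n_embd = cur->ne[0];
    const int64_t stride = n_embd * n_stack;
    const int64_t total  = ggml_nelements(cur);
    const int64_t padded = ((total + stride - 1) / stride) * stride; // round up to a multiple of stride
    if (padded != total) {
        cur = ggml_view_1d(ctx0, cur, total, 0);            // flatten to 1D
        cur = ggml_pad(ctx0, cur, padded - total, 0, 0, 0); // zero-pad dim 0
    }
    return ggml_view_2d(ctx0, cur, stride, padded / stride,
                        ggml_row_size(cur->type, stride), 0);
}

The hard-coded reshape above would then become something like cur = build_stack(ctx0, cur, 4); (or use hparams.n_merge), and the ultravox path could call the same helper with its own stack factor.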

"VLlama3ForCausalLM",
"LlavaForConditionalGeneration",
"VoxtralForConditionalGeneration",
"GlmasrModel",
Collaborator

Suggested change
- "GlmasrModel",

This will get overwritten by lm_config.

Comment on lines +2414 to +2423
if isinstance(self.hparams.get("eos_token_id"), list):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
    special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
    special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
    special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
    special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
    special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
    special_vocab.add_to_gguf(self.gguf_writer)
    special_vocab.chat_template = "glmedge"
Collaborator
@CISC CISC commented Dec 10, 2025

This is not ok; check for the root architecture instead (see Qwen3MoeModel). Also, you should not need to set any of those special tokens.

Also, setting the template name like that doesn't work any more, I think, and it's a dirty hack to begin with; if the model creators can't be bothered, neither should we.

special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
special_vocab.chat_template = "glmedge"
Collaborator

as always, it would be nice if the GLM team could be more responsible about carefully testing and distributing the chat template

we don't generally accept this kind of chat template hack anymore, as it is not supported by the jinja engine


Labels

examples, python

3 participants