model: add glm-asr support #17901
Conversation
```cpp
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
cur = ggml_add(ctx0, cur, model.mm_1_b);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
cur = ggml_add(ctx0, cur, model.mm_2_b);
```
replace this with build_ffn
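A rough sketch of what that could look like, assuming the existing `build_ffn` helper in clip.cpp with its (up, gate, down, activation) argument layout; the exact signature and the `FFN_GELU_ERF` enum value should be checked against the current code:

```cpp
// sketch only: collapse the manual up -> GELU -> down sequence into build_ffn
// (signature and FFN_GELU_ERF value assumed from the existing clip.cpp helpers)
cur = build_ffn(cur,
        model.mm_1_w, model.mm_1_b, // up projection + bias
        nullptr,      nullptr,      // no gate projection
        model.mm_2_w, model.mm_2_b, // down projection + bias
        FFN_GELU_ERF,               // matches the ggml_gelu_erf() call above
        -1);
```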
```cpp
// whisper downscales input token by half after conv1d
n_patches /= 2;
// reshape by merge_factor
n_patches /= 4;
```
Suggested change:

```diff
-n_patches /= 4;
+n_patches /= n_merge;
```
You also need to set `hparams.n_merge = 4` when loading the hparams; see the load_hparams() function.
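For reference, a minimal sketch of where that could go, assuming the PR adds a dedicated projector type for GLM-ASR (the enum name below is a placeholder, not the actual identifier):

```cpp
// in load_hparams(), alongside the other projector-specific cases
case PROJECTOR_TYPE_GLM_ASR: // placeholder name for whatever this PR defines
    {
        // audio frames are merged in groups of 4 before the projector,
        // so token counting can divide by n_merge instead of a literal 4
        hparams.n_merge = 4;
    } break;
```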
```cpp
cur = ggml_norm(ctx0, cur, hparams.eps);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
cur = ggml_add(ctx0, cur, model.mm_norm_pre_b);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * 4, cur->ne[1] / 4);
```
This will fail if the number of elements is not divisible by 4.
Instead, abstract the StackAudioFrames logic used by ultravox into a new function, build_stack(), and reuse it here.
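A sketch of what such a build_stack() helper could look like, factored out of the ultravox StackAudioFrames block; member names such as `ctx0` follow the surrounding graph-builder code, and the zero-padding behaviour is an assumption about how the existing block handles non-divisible lengths:

```cpp
// stack `stack_factor` consecutive frames into one row, zero-padding the
// tail so the element count is always divisible by the stack width
ggml_tensor * build_stack(ggml_tensor * cur, int64_t stack_factor) {
    const int64_t stride     = cur->ne[0] * stack_factor;
    const int64_t n_elems    = ggml_nelements(cur);
    const int64_t padded_len = GGML_PAD(n_elems, stride);
    const int64_t pad        = padded_len - n_elems;
    if (pad > 0) {
        cur = ggml_view_1d(ctx0, cur, n_elems, 0);
        cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
    }
    return ggml_view_2d(ctx0, cur, stride, padded_len / stride,
            ggml_row_size(cur->type, stride), 0);
}
```

The GLM-ASR projector could then call something like `cur = build_stack(cur, 4);` instead of the bare `ggml_reshape_2d`, and the ultravox path could be switched over to the same helper.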
| "VLlama3ForCausalLM", | ||
| "LlavaForConditionalGeneration", | ||
| "VoxtralForConditionalGeneration", | ||
| "GlmasrModel", |
| "GlmasrModel", |
This will get overwritten by lm_config.
```python
if isinstance(self.hparams.get("eos_token_id"), list):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
    special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
    special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
    special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
    special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
    special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
    special_vocab.add_to_gguf(self.gguf_writer)
    special_vocab.chat_template = "glmedge"
```
This is not OK; check for the root architecture instead (see Qwen3MoeModel). You also should not need to set any of those special tokens.
Also, setting the template name like that doesn't work anymore, I think, and it's a dirty hack to begin with; if the model creators can't be bothered, neither should we.
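For illustration, one way to branch on the root architecture rather than probing `eos_token_id`; the helper below is hypothetical (Qwen3MoeModel achieves the same thing by remembering the original `architectures` entry before hparams is swapped):

```python
import json

def _root_architecture(dir_model) -> str | None:
    # the top-level config.json still carries the root "architectures" entry,
    # even after self.hparams has been replaced by the contents of lm_config
    with open(dir_model / "config.json", encoding="utf-8") as f:
        return json.load(f).get("architectures", [None])[0]

# inside set_vocab(), branch on the root architecture instead:
#
#     if _root_architecture(self.dir_model) == "GlmasrModel":
#         ...  # GLM-ASR-specific vocab handling
#     else:
#         super().set_vocab()
```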
```python
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
special_vocab.chat_template = "glmedge"
```
As always, it would be nice if the GLM team could take more responsibility for carefully testing and distributing the chat template.
We generally don't accept this kind of chat template hack anymore, as it is not supported by the Jinja engine.
This PR adds support for the GLM-ASR architecture, validated with the zai-org/GLM-ASR-Nano-2512 model.
Key Changes:
- Updated convert_hf_to_gguf.py to handle dynamic configuration keys (GLM-ASR uses "lm_config" instead of "text_config"). It now identifies the config section by checking: `llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config"`

Result