[model] add glm-asr support
piDack committed Dec 10, 2025
commit f432a5c3ba0335da6dea91ad4d4101302d54c14f
83 changes: 77 additions & 6 deletions convert_hf_to_gguf.py
@@ -736,9 +736,10 @@ def __init__(self, *args, **kwargs):
else:
self.hf_arch = ""

if "text_config" in self.hparams:
llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config"
if llm_config_key in self.hparams:
# move the text_config to the root level
self.hparams = {**self.hparams, **self.hparams["text_config"]}
self.hparams = {**self.hparams, **self.hparams[llm_config_key]}

self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@@ -1604,7 +1605,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]

n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]

has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@@ -1621,11 +1622,12 @@ def __init__(self, *args, **kwargs):

# get n_embd of the text model
if not self.is_mistral_format:
if "text_config" not in self.hparams:
llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config"
if llm_config_key not in self.hparams:
self.hparams["text_config"] = {}
if "audio_config" not in self.hparams:
self.hparams["audio_config"] = {}
text_config = {**self.hparams, **self.hparams["text_config"]}
text_config = {**self.hparams, **self.hparams[llm_config_key]}
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
else:
text_config = {
@@ -1680,7 +1682,8 @@ def get_vision_config(self) -> dict[str, Any] | None:
return self.global_config.get(config_name)

def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config.get("audio_config")
mm_config_key = "whisper_config" if "whisper_config" in self.hparams else "audio_config"
return self.global_config.get(mm_config_key)

def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
@@ -2356,6 +2359,7 @@ def prepare_tensors(self):
"VLlama3ForCausalLM",
"LlavaForConditionalGeneration",
"VoxtralForConditionalGeneration",
"GlmasrModel",
Collaborator:
Suggested change (remove this line):
    "GlmasrModel",

This will get overwritten by lm_config.

"LlamaModel")
class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA
@@ -2407,6 +2411,17 @@ def set_vocab(self):
# Apply to granite small models only
if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False)

if isinstance(self.hparams.get("eos_token_id"), list):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
special_vocab.chat_template = "glmedge"
Comment on lines +2414 to +2423
@CISC (Collaborator), Dec 10, 2025:

This is not ok; check for the root architecture instead (see Qwen3MoeModel). You also should not need to set any of those special tokens.

Also, setting the template name like that doesn't work any more, I think, and it is a dirty hack to begin with; if the model creators can't be bothered, neither should we.

Collaborator:
As always, it would be nice if the GLM team took more responsibility for carefully testing and distributing the chat template.

We don't generally accept this kind of chat template hack anymore, as it is not supported by the jinja engine.


def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -2443,6 +2458,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
"vision_language_adapter.",
"patch_merger.",
"pre_mm_projector_norm",
"audio_encoder.",
]

is_multimodal_tensor = "vision_tower" in name \
@@ -8998,6 +9014,61 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument")

@ModelBase.register("GlmasrModel")
class GlmASRWhisperEncoderModel(MmprojModel):
has_vision_encoder = False
has_audio_encoder = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
self.hparams["hidden_size"] = self.hparams["d_model"]
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLMA)
self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))

def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".conv" in name and ".weight" in name:
return gguf.GGMLQuantizationType.F16
return super().tensor_force_quant(name, new_name, bid, n_dims)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if name.startswith("model.") or name.startswith("lm_head."):
# skip language model tensors
return []

if name.startswith("audio_encoder.whisper."):
name = name.replace("audio_encoder.whisper.","audio_tower.")
if "audio_encoder.layer_norm." in name or "audio_encoder.proj." in name:
name = name.replace("audio_encoder.", "audio_encoder.adapting.")

if name.startswith("audio_encoder.audio_bos_eos_token."):
return [(self.map_tensor_name("model.vision.boi"), data_torch[0]), (self.map_tensor_name("model.vision.eoi"), data_torch[1])]

if name.startswith("audio_encoder.adapting."):
name = name.replace("audio_encoder.adapting.","audio.multi_modal_projector.")
if ".layer_norm." in name:
name = name.replace(".layer_norm.", ".ln_pre.")
if ".0." in name:
name = name.replace(".0.", ".linear_1.")
if ".2." in name:
name = name.replace(".2.", ".linear_2.")
if ".proj." in name:
print("skip proj")
return []

if "conv1.bias" in name or "conv2.bias" in name:
# transpose conv1 and conv2 bias
data_torch = data_torch.unsqueeze(-1)

return [(self.map_tensor_name(name), data_torch)]

@ModelBase.register("Qwen2AudioForConditionalGeneration")
class WhisperEncoderModel(MmprojModel):
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -3320,6 +3320,7 @@ class VisionProjectorType:
ULTRAVOX = "ultravox"
INTERNVL = "internvl"
QWEN2A = "qwen2a" # audio
GLMA = "glma" # audio
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
LFM2 = "lfm2"
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
@@ -149,6 +149,7 @@ enum projector_type {
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_QWEN2A,
PROJECTOR_TYPE_GLMA,
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_LFM2,
@@ -175,6 +176,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
{ PROJECTOR_TYPE_GLMA, "glma"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
44 changes: 43 additions & 1 deletion tools/mtmd/clip.cpp
@@ -388,6 +388,7 @@ struct clip_model {
ggml_tensor * conv1d_2_w = nullptr;
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_pre_b = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;

// cogvlm
@@ -1829,7 +1830,6 @@ struct clip_graph {
GGML_ASSERT(model.layers[0].q_b);
GGML_ASSERT(model.layers[0].v_b);
GGML_ASSERT(!model.layers[0].k_b); // no bias for k
GGML_ASSERT(model.post_ln_w && model.post_ln_b);

ggml_tensor * pos_embd_selected = ggml_view_2d(
ctx0, model.position_embeddings,
@@ -1891,6 +1891,18 @@ struct clip_graph {
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);

} else if (ctx->proj_type() == PROJECTOR_TYPE_GLMA) {
cur = ggml_norm(ctx0, cur, hparams.eps);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
cur = ggml_add(ctx0, cur, model.mm_norm_pre_b);
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * 4, cur->ne[1] / 4);
@ngxson (Collaborator), Dec 10, 2025:
This will fail if the number of elements is not divisible by 4.

Instead, you should abstract the StackAudioFrames logic used by ultravox into a new function, build_stack(), and reuse it here.
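For reference, here is a minimal sketch of what such a helper could look like, modeled on the ultravox frame-stacking code already in clip.cpp; the free-function form, the parameter list, and the names are illustrative, not an agreed API:

// Sketch only: stack groups of `stack_factor` consecutive frames into single rows,
// zero-padding the tail so the element count is always divisible by the factor.
static ggml_tensor * build_stack(ggml_context * ctx0, ggml_tensor * cur, int64_t stack_factor) {
    const int64_t n_embd     = cur->ne[0];
    const int64_t stride     = n_embd * stack_factor;
    const int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
    const int64_t pad        = padded_len - ggml_nelements(cur);

    cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
    if (pad > 0) {
        cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); // pad the flattened tensor with zeros
    }
    // regroup every `stack_factor` consecutive frames into one row
    return ggml_view_2d(ctx0, cur, stride, padded_len / stride,
                        ggml_row_size(cur->type, stride), 0);
}

The GLMA branch would then call cur = build_stack(ctx0, cur, 4); in place of the bare ggml_reshape_2d, so clips whose frame count is not a multiple of 4 get zero-padded instead of tripping the reshape.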

cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
cur = ggml_add(ctx0, cur, model.mm_1_b);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
cur = ggml_add(ctx0, cur, model.mm_2_b);
Comment on lines +1899 to +1903
Collaborator:
replace this with build_ffn
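In other words, the mm_1/mm_2 matmuls and bias adds above could collapse into one call to the shared FFN builder. The sketch below assumes the usual build_ffn parameter order (up, gate, down, op type, layer index) and an FFN_GELU_ERF op name, both taken from how other projector branches call it; verify against the current clip.cpp before relying on it:

// Sketch: same math as the hand-written mm_1 -> gelu_erf -> mm_2 sequence,
// expressed through the shared FFN builder (parameter order assumed, verify).
cur = build_ffn(cur,
        model.mm_1_w, model.mm_1_b,  // "up" projection
        nullptr,      nullptr,       // no gate branch
        model.mm_2_w, model.mm_2_b,  // "down" projection
        FFN_GELU_ERF,                // matches the explicit ggml_gelu_erf above
        -1);                         // projector has no layer index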

cur = ggml_concat(ctx0, model.mm_boi, cur, 1);
cur = ggml_concat(ctx0, cur, model.mm_eoi, 1);
} else {
GGML_ABORT("%s: unknown projector type", __func__);
}
@@ -2518,6 +2530,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_GLMA:
{
res = graph.build_whisper_enc();
} break;
@@ -3225,6 +3238,21 @@ struct clip_model_loader {
model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
} break;
case PROJECTOR_TYPE_GLMA:
{
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias"));
model.mm_boi = get_tensor(string_format(TN_TOK_BOI, "weight"));
model.mm_eoi = get_tensor(string_format(TN_TOK_EOI, "weight"));
} break;
case PROJECTOR_TYPE_LLAMA4:
{
model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
@@ -4606,6 +4634,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches /= 2;
}
} break;
case PROJECTOR_TYPE_GLMA:
{
n_patches = img->nx;
// whisper downscales input token by half after conv1d
n_patches /= 2;
// reshape by merge_factor
n_patches /= 4;
Collaborator:
Suggested change: replace
    n_patches /= 4;
with
    n_patches /= n_merge;

You also need to set hparams.n_merge = 4 when loading hparams; see the load_hparams() function.
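Spelled out, the suggestion touches two places. The sketch below is illustrative only; in particular, the way hparams is reached inside clip_n_output_tokens (written here as ctx->model.hparams) should be checked against the surrounding code:

// In clip_model_loader::load_hparams(): record the merge factor once.
case PROJECTOR_TYPE_GLMA:
    {
        hparams.n_merge = 4; // four whisper frames are stacked per projected token
    } break;

// In clip_n_output_tokens(): use it instead of the hard-coded 4.
case PROJECTOR_TYPE_GLMA:
    {
        n_patches = img->nx;
        n_patches /= 2;                          // whisper conv downsampling
        n_patches /= ctx->model.hparams.n_merge; // frame stacking factor
        n_patches += 2;                          // BOI and EOI token embeddings
    } break;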

// for BOI and EOI token embeddings
n_patches += 2;
} break;
case PROJECTOR_TYPE_COGVLM:
{
n_patches += 2; // for BOI and EOI token embeddings
@@ -4941,6 +4979,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_GLMA:
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_VOXTRAL:
@@ -5051,6 +5090,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_model_proj->ne[1];
case PROJECTOR_TYPE_QWEN2A:
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_GLMA:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
return ctx->model.mm_2_w->ne[1];
@@ -5097,6 +5138,7 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
|| ctx->proj_type() == PROJECTOR_TYPE_GLMA
|| ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL;
}
