diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2cc2a388236..4acca39d88f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -736,9 +736,10 @@ def __init__(self, *args, **kwargs): else: self.hf_arch = "" - if "text_config" in self.hparams: + llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config" + if llm_config_key in self.hparams: # move the text_config to the root level - self.hparams = {**self.hparams, **self.hparams["text_config"]} + self.hparams = {**self.hparams, **self.hparams[llm_config_key]} self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @@ -1604,7 +1605,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1621,11 +1622,12 @@ def __init__(self, *args, **kwargs): # get n_embd of the text model if not self.is_mistral_format: - if "text_config" not in self.hparams: + llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config" + if llm_config_key not in self.hparams: self.hparams["text_config"] = {} if "audio_config" not in self.hparams: self.hparams["audio_config"] = {} - text_config = {**self.hparams, **self.hparams["text_config"]} + text_config = {**self.hparams, **self.hparams[llm_config_key]} self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0)) else: text_config = { @@ -1680,7 +1682,8 @@ def get_vision_config(self) -> dict[str, Any] | None: return self.global_config.get(config_name) def get_audio_config(self) -> dict[str, Any] | None: - return self.global_config.get("audio_config") + mm_config_key = "whisper_config" if "whisper_config" in self.hparams else "audio_config" + return self.global_config.get(mm_config_key) def set_type(self): self.gguf_writer.add_type(gguf.GGUFType.MMPROJ) @@ -2356,6 +2359,7 @@ def prepare_tensors(self): "VLlama3ForCausalLM", "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", + "GlmasrModel", "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -2407,6 +2411,16 @@ def set_vocab(self): # Apply to granite small models only if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) + if isinstance(self.hparams.get("eos_token_id"), list): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + special_vocab.chat_template = "glmedge" def set_gguf_parameters(self): super().set_gguf_parameters() @@ -2443,6 +2457,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter "vision_language_adapter.", "patch_merger.", "pre_mm_projector_norm", + "audio_encoder.", ] is_multimodal_tensor = "vision_tower" in name \ @@ -8999,6 +9014,62 @@ def __init__(self, *args, **kwargs): raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument") +@ModelBase.register("GlmasrModel") +class GlmASRWhisperEncoderModel(MmprojModel): + has_vision_encoder = False + has_audio_encoder = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams: + self.hparams["hidden_size"] = self.hparams["d_model"] + self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"] + self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLMA) + self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + if ".conv" in name and ".weight" in name: + return gguf.GGMLQuantizationType.F16 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.startswith("model.") or name.startswith("lm_head."): + # skip language model tensors + return [] + + if name.startswith("audio_encoder.whisper."): + name = name.replace("audio_encoder.whisper.","audio_tower.") + if "audio_encoder.layer_norm." in name or "audio_encoder.proj." in name: + name = name.replace("audio_encoder.", "audio_encoder.adapting.") + + if name.startswith("audio_encoder.audio_bos_eos_token."): + return [(self.map_tensor_name("model.vision.boi"), data_torch[0]), (self.map_tensor_name("model.vision.eoi"), data_torch[1])] + + if name.startswith("audio_encoder.adapting."): + name = name.replace("audio_encoder.adapting.","audio.multi_modal_projector.") + if ".layer_norm." in name: + name = name.replace(".layer_norm.", ".ln_pre.") + if ".0." in name: + name = name.replace(".0.", ".linear_1.") + if ".2." in name: + name = name.replace(".2.", ".linear_2.") + if ".proj." in name: + return [] + + if "conv1.bias" in name or "conv2.bias" in name: + # transpose conv1 and conv2 bias + data_torch = data_torch.unsqueeze(-1) + + return [(self.map_tensor_name(name), data_torch)] + + @ModelBase.register("Qwen2AudioForConditionalGeneration") class WhisperEncoderModel(MmprojModel): has_vision_encoder = False # no vision encoder diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2b8489c591b..8ef4a23a104 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3320,6 +3320,7 @@ class VisionProjectorType: ULTRAVOX = "ultravox" INTERNVL = "internvl" QWEN2A = "qwen2a" # audio + GLMA = "glma" # audio QWEN25O = "qwen2.5o" # omni VOXTRAL = "voxtral" LFM2 = "lfm2" diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index cd47865bf4a..93153226d45 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -149,6 +149,7 @@ enum projector_type { PROJECTOR_TYPE_INTERNVL, PROJECTOR_TYPE_LLAMA4, PROJECTOR_TYPE_QWEN2A, + PROJECTOR_TYPE_GLMA, PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_LFM2, @@ -175,6 +176,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_INTERNVL, "internvl"}, { PROJECTOR_TYPE_LLAMA4, "llama4"}, { PROJECTOR_TYPE_QWEN2A, "qwen2a"}, + { PROJECTOR_TYPE_GLMA, "glma"}, { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3ed08a0fec6..32318aec5d5 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -388,6 +388,7 @@ struct clip_model { ggml_tensor * conv1d_2_w = nullptr; ggml_tensor * conv1d_2_b = nullptr; ggml_tensor * mm_norm_pre_w = nullptr; + ggml_tensor * mm_norm_pre_b = nullptr; ggml_tensor * mm_norm_mid_w = nullptr; // cogvlm @@ -1829,7 +1830,6 @@ struct clip_graph { GGML_ASSERT(model.layers[0].q_b); GGML_ASSERT(model.layers[0].v_b); GGML_ASSERT(!model.layers[0].k_b); // no bias for k - GGML_ASSERT(model.post_ln_w && model.post_ln_b); ggml_tensor * pos_embd_selected = ggml_view_2d( ctx0, model.position_embeddings, @@ -1891,6 +1891,18 @@ struct clip_graph { cur = ggml_gelu_erf(ctx0, cur); cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + } else if (ctx->proj_type() == PROJECTOR_TYPE_GLMA) { + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); + cur = ggml_add(ctx0, cur, model.mm_norm_pre_b); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * 4, cur->ne[1] / 4); + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_add(ctx0, cur, model.mm_1_b); + cur = ggml_gelu_erf(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + cur = ggml_add(ctx0, cur, model.mm_2_b); + cur = ggml_concat(ctx0, model.mm_boi, cur, 1); + cur = ggml_concat(ctx0, cur, model.mm_eoi, 1); } else { GGML_ABORT("%s: unknown projector type", __func__); } @@ -2518,6 +2530,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_GLMA: { res = graph.build_whisper_enc(); } break; @@ -3225,6 +3238,21 @@ struct clip_model_loader { model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); } break; + case PROJECTOR_TYPE_GLMA: + { + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias")); + model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias")); + model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight")); + model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias")); + model.mm_boi = get_tensor(string_format(TN_TOK_BOI, "weight")); + model.mm_eoi = get_tensor(string_format(TN_TOK_EOI, "weight")); + } break; case PROJECTOR_TYPE_LLAMA4: { model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); @@ -4606,6 +4634,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches /= 2; } } break; + case PROJECTOR_TYPE_GLMA: + { + n_patches = img->nx; + // whisper downscales input token by half after conv1d + n_patches /= 2; + // reshape by merge_factor + n_patches /= 4; + // for BOI and EOI token embeddings + n_patches += 2; + } break; case PROJECTOR_TYPE_COGVLM: { n_patches += 2; // for BOI and EOI token embeddings @@ -4941,6 +4979,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_VOXTRAL: @@ -5051,6 +5090,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_QWEN2A: return ctx->model.mm_fc_w->ne[1]; + case PROJECTOR_TYPE_GLMA: + return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: return ctx->model.mm_2_w->ne[1]; @@ -5097,6 +5138,7 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) { bool clip_has_whisper_encoder(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A + || ctx->proj_type() == PROJECTOR_TYPE_GLMA || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL; }