model: add glm-asr support #17901
base: master
Changes from 1 commit
convert_hf_to_gguf.py
```diff
@@ -736,9 +736,10 @@ def __init__(self, *args, **kwargs):
         else:
             self.hf_arch = ""

-        if "text_config" in self.hparams:
+        llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config"
+        if llm_config_key in self.hparams:
             # move the text_config to the root level
-            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+            self.hparams = {**self.hparams, **self.hparams[llm_config_key]}

         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
```
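As a quick illustration of what this key fallback buys, the snippet below flattens a hypothetical GLM-ASR-style config the same way (the dict contents are illustrative, not taken from the PR):

```python
# Minimal sketch of the config flattening above, assuming a GLM-ASR-style
# checkpoint that nests its text-model config under "lm_config" instead of
# the usual "text_config" (values here are hypothetical).
hparams = {
    "architectures": ["GlmasrModel"],
    "lm_config": {"hidden_size": 4096, "num_hidden_layers": 40},
    "whisper_config": {"d_model": 1280},
}

llm_config_key = "lm_config" if "lm_config" in hparams else "text_config"
if llm_config_key in hparams:
    # hoist the nested text-model config to the root level
    hparams = {**hparams, **hparams[llm_config_key]}

print(hparams["hidden_size"])  # 4096, now visible to root-level lookups
```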
```diff
@@ -1604,7 +1605,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]

-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]

     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False
```
```diff
@@ -1621,11 +1622,12 @@ def __init__(self, *args, **kwargs):

         # get n_embd of the text model
         if not self.is_mistral_format:
-            if "text_config" not in self.hparams:
+            llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config"
+            if llm_config_key not in self.hparams:
                 self.hparams["text_config"] = {}
             if "audio_config" not in self.hparams:
                 self.hparams["audio_config"] = {}
-            text_config = {**self.hparams, **self.hparams["text_config"]}
+            text_config = {**self.hparams, **self.hparams[llm_config_key]}
             self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         else:
             text_config = {
```
```diff
@@ -1680,7 +1682,8 @@ def get_vision_config(self) -> dict[str, Any] | None:
         return self.global_config.get(config_name)

     def get_audio_config(self) -> dict[str, Any] | None:
-        return self.global_config.get("audio_config")
+        mm_config_key = "whisper_config" if "whisper_config" in self.hparams else "audio_config"
+        return self.global_config.get(mm_config_key)

     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
```
```diff
@@ -2356,6 +2359,7 @@ def prepare_tensors(self):
                      "VLlama3ForCausalLM",
                      "LlavaForConditionalGeneration",
                      "VoxtralForConditionalGeneration",
+                     "GlmasrModel",
                      "LlamaModel")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
```
```diff
@@ -2407,6 +2411,17 @@ def set_vocab(self):
         # Apply to granite small models only
         if self.hparams.get("vocab_size", 32000) == 49152:
             self.gguf_writer.add_add_bos_token(False)

+        if isinstance(self.hparams.get("eos_token_id"), list):
+            from transformers import AutoTokenizer
+            tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+            special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+            special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+            special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+            special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
+            special_vocab.add_to_gguf(self.gguf_writer)
+            special_vocab.chat_template = "glmedge"
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
```

Comment on lines +2414 to +2423

Collaborator: This is not ok; check for the root architecture instead (see …). Also, setting the template name like that doesn't work any more, I think, and it's a dirty hack to begin with. If the model creators can't be bothered, neither should we.

Collaborator: As always, it would be nice if the GLM team could be more responsible about carefully testing and distributing the chat template. We don't generally accept this kind of chat template hack anymore, as it is not supported by the Jinja engine.
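For reference, a minimal sketch of the check the reviewer asks for, assuming the converter's `self.hf_arch` attribute (derived from the checkpoint's `architectures` list) is the intended signal; the exact condition may differ in the final PR:

```python
# Hypothetical: gate the GLM-specific special-token setup on the detected
# root architecture rather than on eos_token_id happening to be a list.
if self.hf_arch == "GlmasrModel":
    ...  # GLM-ASR-specific vocab handling would go here
```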
```diff
@@ -2443,6 +2458,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
             "vision_language_adapter.",
             "patch_merger.",
             "pre_mm_projector_norm",
+            "audio_encoder.",
         ]

         is_multimodal_tensor = "vision_tower" in name \
```
```diff
@@ -8998,6 +9014,61 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument")

+@ModelBase.register("GlmasrModel")
+class GlmASRWhisperEncoderModel(MmprojModel):
+    has_vision_encoder = False
+    has_audio_encoder = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
+            self.hparams["hidden_size"] = self.hparams["d_model"]
+            self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
+            self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLMA)
+        self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
+        self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        if ".conv" in name and ".weight" in name:
+            return gguf.GGMLQuantizationType.F16
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.startswith("model.") or name.startswith("lm_head."):
+            # skip language model tensors
+            return []
+
+        if name.startswith("audio_encoder.whisper."):
+            name = name.replace("audio_encoder.whisper.", "audio_tower.")
+        if "audio_encoder.layer_norm." in name or "audio_encoder.proj." in name:
+            name = name.replace("audio_encoder.", "audio_encoder.adapting.")
+
+        if name.startswith("audio_encoder.audio_bos_eos_token."):
+            return [(self.map_tensor_name("model.vision.boi"), data_torch[0]), (self.map_tensor_name("model.vision.eoi"), data_torch[1])]
+
+        if name.startswith("audio_encoder.adapting."):
+            name = name.replace("audio_encoder.adapting.", "audio.multi_modal_projector.")
+            if ".layer_norm." in name:
+                name = name.replace(".layer_norm.", ".ln_pre.")
+            if ".0." in name:
+                name = name.replace(".0.", ".linear_1.")
+            if ".2." in name:
+                name = name.replace(".2.", ".linear_2.")
+            if ".proj." in name:
+                print("skip proj")
+                return []
+
+        if "conv1.bias" in name or "conv2.bias" in name:
+            # transpose conv1 and conv2 bias
+            data_torch = data_torch.unsqueeze(-1)
+
+        return [(self.map_tensor_name(name), data_torch)]
+
 @ModelBase.register("Qwen2AudioForConditionalGeneration")
 class WhisperEncoderModel(MmprojModel):
```
tools/mtmd/clip.cpp
```diff
@@ -388,6 +388,7 @@ struct clip_model {
     ggml_tensor * conv1d_2_w = nullptr;
     ggml_tensor * conv1d_2_b = nullptr;
     ggml_tensor * mm_norm_pre_w = nullptr;
+    ggml_tensor * mm_norm_pre_b = nullptr;
     ggml_tensor * mm_norm_mid_w = nullptr;

     // cogvlm
```
```diff
@@ -1829,7 +1830,6 @@ struct clip_graph {
     GGML_ASSERT(model.layers[0].q_b);
     GGML_ASSERT(model.layers[0].v_b);
     GGML_ASSERT(!model.layers[0].k_b); // no bias for k
-    GGML_ASSERT(model.post_ln_w && model.post_ln_b);

     ggml_tensor * pos_embd_selected = ggml_view_2d(
         ctx0, model.position_embeddings,
```
```diff
@@ -1891,6 +1891,18 @@ struct clip_graph {
             cur = ggml_gelu_erf(ctx0, cur);
             cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);

+        } else if (ctx->proj_type() == PROJECTOR_TYPE_GLMA) {
+            cur = ggml_norm(ctx0, cur, hparams.eps);
+            cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
+            cur = ggml_add(ctx0, cur, model.mm_norm_pre_b);
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * 4, cur->ne[1] / 4);
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu_erf(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_2_b);
+            cur = ggml_concat(ctx0, model.mm_boi, cur, 1);
+            cur = ggml_concat(ctx0, cur, model.mm_eoi, 1);
         } else {
             GGML_ABORT("%s: unknown projector type", __func__);
         }
```

Collaborator (on the `ggml_reshape_2d` line): This will fail if the number of elements is not divisible by 4. Instead, you should abstract out the …

Comment on lines +1899 to +1903

Collaborator: Replace this with build_ffn.
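For orientation, here is a hedged plain-PyTorch sketch of what this GLMA projector branch computes. Tensor names mirror the `clip_model` fields, the shapes are illustrative assumptions, and the assert makes explicit the divisibility constraint the reviewer flags:

```python
import torch
import torch.nn.functional as F

def glma_project(cur, norm_w, norm_b, mm_1_w, mm_1_b, mm_2_w, mm_2_b,
                 boi, eoi, eps=1e-5, merge_factor=4):
    # pre-norm with mm_norm_pre_w / mm_norm_pre_b
    cur = F.layer_norm(cur, cur.shape[-1:], norm_w, norm_b, eps)
    # stack merge_factor consecutive frames into one wider embedding;
    # ggml_reshape_2d has no implicit padding, hence the reviewer's concern
    assert cur.shape[0] % merge_factor == 0, "frame count not divisible by merge_factor"
    cur = cur.reshape(cur.shape[0] // merge_factor, cur.shape[1] * merge_factor)
    # two-layer MLP with erf-based GELU (the part build_ffn could express)
    cur = F.gelu(cur @ mm_1_w.T + mm_1_b)
    cur = cur @ mm_2_w.T + mm_2_b
    # surround the sequence with the BOI / EOI embeddings
    return torch.cat([boi, cur, eoi], dim=0)

# Example: 1500 post-conv frames of width 1280, projected to width 4096.
n, d, d_out = 1500, 1280, 4096
out = glma_project(torch.randn(n, d),
                   torch.ones(d), torch.zeros(d),
                   torch.randn(d_out, d * 4), torch.zeros(d_out),
                   torch.randn(d_out, d_out), torch.zeros(d_out),
                   torch.zeros(1, d_out), torch.zeros(1, d_out))
print(out.shape)  # torch.Size([377, 4096]) -- 1500/4 frames plus BOI and EOI
```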
```diff
@@ -2518,6 +2530,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_GLMA:
             {
                 res = graph.build_whisper_enc();
             } break;
```
```diff
@@ -3225,6 +3238,21 @@ struct clip_model_loader {
                 model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
                 model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
             } break;
+        case PROJECTOR_TYPE_GLMA:
+            {
+                model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
+                model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+                model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
+                model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
+                model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias"));
+                model.mm_boi = get_tensor(string_format(TN_TOK_BOI, "weight"));
+                model.mm_eoi = get_tensor(string_format(TN_TOK_EOI, "weight"));
+            } break;
         case PROJECTOR_TYPE_LLAMA4:
             {
                 model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
```
```diff
@@ -4606,6 +4634,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                     n_patches /= 2;
                 }
             } break;
+        case PROJECTOR_TYPE_GLMA:
+            {
+                n_patches = img->nx;
+                // whisper downscales input token by half after conv1d
+                n_patches /= 2;
+                // reshape by merge_factor
+                n_patches /= 4;
+                // for BOI and EOI token embeddings
+                n_patches += 2;
+            } break;
         case PROJECTOR_TYPE_COGVLM:
             {
                 n_patches += 2; // for BOI and EOI token embeddings
```

Collaborator proposed a suggested change on these lines, noting: and you also need to set …
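A worked instance of this arithmetic, assuming a Whisper-style mel front end where a 30-second clip produces nx = 3000 spectrogram frames (the 3000 is an assumption for illustration; the /2, /4, and +2 steps come from the code above):

```python
# Token count for the GLMA branch of clip_n_output_tokens, step by step.
nx = 3000             # mel frames for a 30 s clip (assumed Whisper-style input)
n_patches = nx // 2   # conv1d stride-2 downscaling    -> 1500
n_patches //= 4       # merge_factor-4 frame stacking  -> 375
n_patches += 2        # BOI and EOI token embeddings   -> 377
print(n_patches)      # 377
```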
```diff
@@ -4941,6 +4979,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
         case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_GLMA:
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_VOXTRAL:
```
```diff
@@ -5051,6 +5090,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_proj->ne[1];
         case PROJECTOR_TYPE_QWEN2A:
             return ctx->model.mm_fc_w->ne[1];
+        case PROJECTOR_TYPE_GLMA:
+            return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
             return ctx->model.mm_2_w->ne[1];
```
```diff
@@ -5097,6 +5138,7 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
 bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
         || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
+        || ctx->proj_type() == PROJECTOR_TYPE_GLMA
         || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL;
 }
```
Collaborator: This will get overwritten by lm_config.
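One plausible reading of this comment, as a minimal sketch: when a checkpoint ships an lm_config, the dict merge in the converter lets the nested values overwrite whatever sits at the root (the config values below are hypothetical):

```python
# Hypothetical checkpoint config: root-level and nested values disagree.
hparams = {
    "hidden_size": 1024,                 # stale root-level value
    "lm_config": {"hidden_size": 4096},  # authoritative text-model config
}

llm_config_key = "lm_config" if "lm_config" in hparams else "text_config"
merged = {**hparams, **hparams[llm_config_key]}
print(merged["hidden_size"])  # 4096: the lm_config value wins the merge
```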