diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ce83f24695ec7..8142fa9328327 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -888,6 +888,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756": # ref: https://huggingface.co/JetBrains/Mellum-4b-base res = "mellum" + if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152": + # ref: https://huggingface.co/ibm-granite/granite-embedding-small-english-r2 + res = "modern-bert" if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206": # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base res = "llada-moe" @@ -8733,6 +8736,35 @@ def prepare_tensors(self): experts = [k for d in self._experts for k in d.keys()] if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification") +class ModernBertModel(BertModel): + model_arch = gguf.MODEL_ARCH.MODERN_BERT + + def set_vocab(self): + self._set_vocab_gpt2() + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_sliding_window(self.hparams["local_attention"]) + self.gguf_writer.add_rope_freq_base(self.hparams["global_rope_theta"]) + self.gguf_writer.add_rope_freq_base_swa(self.hparams["local_rope_theta"]) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # These layers act as MLM head, so we don't need them + if name.startswith("decoder."): + return [] + + if name.startswith("model."): + name = name[6:] + + return super().modify_tensors(data_torch, name, bid) + class MistralModel(LlamaModel): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 21bb4a9f3e5e6..8f2467194dee8 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"}, {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", }, {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, + {"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-small-english-r2", }, {"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", }, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 7e16cbcbde2cb..0c520233d25fa 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -158,6 +158,7 @@ class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" FREQ_BASE = "{arch}.rope.freq_base" + FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" SCALING_TYPE = "{arch}.rope.scaling.type" SCALING_FACTOR = "{arch}.rope.scaling.factor" SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" @@ -322,6 +323,7 @@ class MODEL_ARCH(IntEnum): STARCODER = auto() REFACT = auto() BERT = auto() + MODERN_BERT = auto() NOMIC_BERT = auto() NOMIC_BERT_MOE = auto() NEO_BERT = auto() @@ -658,6 +660,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.STARCODER: "starcoder", MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", + MODEL_ARCH.MODERN_BERT: "modern-bert", MODEL_ARCH.NOMIC_BERT: "nomic-bert", MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", MODEL_ARCH.NEO_BERT: "neo-bert", @@ -1194,6 +1197,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.CLS, MODEL_TENSOR.CLS_OUT, ], + MODEL_ARCH.MODERN_BERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, + ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d925fca7e3e11..12b56fc2ec907 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -820,6 +820,9 @@ def add_iclr_lora_rank(self, length: int) -> None: def add_value_residual_mix_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length) + def add_rope_freq_base_swa(self, value: float) -> None: + self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value) + def add_gate_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8fd9e454e0e81..1196e74f97919 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -17,6 +17,7 @@ class TensorNameMap: "embed_tokens", # embeddinggemma "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert + "embeddings.tok_embeddings", # modern bert "language_model.embedding.word_embeddings", # persimmon "wte", # gpt2 "transformer.embd.wte", # phi2 @@ -46,6 +47,7 @@ class TensorNameMap: MODEL_TENSOR.TOKEN_EMBD_NORM: ( "word_embeddings_layernorm", # bloom "embeddings.LayerNorm", # bert + "embeddings.norm", # modern bert "emb_ln", # nomic-bert "transformer.norm", # openelm "rwkv.blocks.0.pre_ln", # rwkv @@ -99,6 +101,7 @@ class TensorNameMap: "backbone.final_layer_norm", # wavtokenizer "model.norm", # llama4 "model.transformer.ln_f", # llada + "final_norm", # modern bert ), # Rope frequencies @@ -145,9 +148,10 @@ class TensorNameMap: "model.layers.{bid}.input_layernorm", # llama4 "layers.{bid}.input_layernorm", # embeddinggemma "transformer_encoder.{bid}.attention_norm", # neobert + "layers.{bid}.attn_norm", # bert "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada - "layers.{bid}.input_layernorm", # qwen3-embedding + "layers.{bid}.input_layernorm", # qwen3-embedding, ), # Attention norm 2 @@ -177,6 +181,7 @@ class TensorNameMap: "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm "transformer_encoder.{bid}.qkv", # neobert + "layers.{bid}.attn.Wqkv", # modern bert ), # Attention query @@ -250,6 +255,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert + "layers.{bid}.attn.Wo", # modern bert "transformer.layer.{bid}.attention.out_lin", # distillbert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon @@ -325,6 +331,7 @@ class TensorNameMap: "model.layers.layers.{bid}.pre_mlp_norm", # plamo2 "model.transformer.blocks.{bid}.ff_norm", # llada "layers.{bid}.post_attention_layernorm", # qwen3-embedding + "layers.{bid}.mlp_norm" # modern bert ), # Post feed-forward norm @@ -378,6 +385,7 @@ class TensorNameMap: "layers.{bid}.mlp.up_proj", # embeddinggemma "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert + "layers.{bid}.mlp.Wi", # modern bert "transformer.layer.{bid}.ffn.lin1", # distillbert "transformer.h.{bid}.mlp.fc_in", # gpt-j "transformer.h.{bid}.mlp.linear_3", # refact @@ -479,6 +487,7 @@ class TensorNameMap: "layers.{bid}.mlp.down_proj", # embeddinggemma "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert + "layers.{bid}.mlp.Wo", # modern bert "transformer.layer.{bid}.ffn.lin2", # distillbert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a4d2973ada5dc..6c4b689f02b7d 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -18,6 +18,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_REFACT, "refact" }, { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_MODERN_BERT, "modern-bert" }, { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, { LLM_ARCH_NEO_BERT, "neo-bert" }, @@ -176,9 +177,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, @@ -525,6 +528,23 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CLS_OUT, "cls.output" }, }, }, + { + LLM_ARCH_MODERN_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, + }, + }, { LLM_ARCH_NOMIC_BERT, { diff --git a/src/llama-arch.h b/src/llama-arch.h index d181ce6784ffb..c3421181f7064 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -22,6 +22,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_REFACT, LLM_ARCH_BERT, + LLM_ARCH_MODERN_BERT, LLM_ARCH_NOMIC_BERT, LLM_ARCH_NOMIC_BERT_MOE, LLM_ARCH_NEO_BERT, @@ -184,6 +185,7 @@ enum llm_kv { LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_FREQ_BASE_SWA, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, LLM_KV_ROPE_SCALING_ATTN_FACTOR, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 9f2e417f1ff4b..62e68a912d5ca 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1578,7 +1578,6 @@ ggml_tensor * llm_graph_context::build_attn( // optionally store to KV cache if (k_cur) { const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs(); - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2be807a6a9dab..726e614a295e4 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -783,6 +783,27 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MODERN_BERT: + { + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + hparams.set_swa_pattern(3, 0); + hparams.rope_freq_base_train_swa = 10000.f; + hparams.rope_freq_base_train = 160000.f; + hparams.n_swa = 128; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 12: + type = LLM_TYPE_47M; break; // granite-embeddings-small + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2275,7 +2296,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { first_moved_to_buft = buft; } } - + ggml_context * ctx = ctx_for_buft(buft); // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one @@ -2776,7 +2797,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); if (arch == LLM_ARCH_BERT) { pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); @@ -2833,6 +2854,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); } } break; + case LLM_ARCH_MODERN_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for(int i = 0; i < n_layer; ++i) { + auto& layer = layers[i]; + + if ( i != 0 ) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + } + else{ + // layer 0 uses identity so we dont need weights for said layer + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + } break; case LLM_ARCH_NEO_BERT: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7787,6 +7841,128 @@ struct llm_build_bert : public llm_graph_context { } }; +template +struct llm_build_modern_bert : public llm_graph_context { + llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const float rope_theta_global = hparams.rope_freq_base_train; + const float rope_theta_local = hparams.rope_freq_base_train_swa; + //const uint32_t n_swa_local = hparams.n_swa; + //const uint32_t n_swa_global = 4096; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur = nullptr; + ggml_tensor * inpL = nullptr; + ggml_tensor * inp_pos = build_inp_pos(); + + // construct input embeddings (token, type, position) + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); + + // embed layer norm + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); + cb(inpL, "inp_norm", -1); + + auto * inp_attn = build_attn_inp_kv_iswa(); + + // iterate layers + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * cur = inpL; + + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + const float rope_theta = (il+1) % 3 == 0 ? rope_theta_global : rope_theta_local; + + + // attention layer norm + if (model.layers[il].attn_norm) { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + } + + // self attention + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd*2))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, rope_theta, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + + if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); + + // attention layer norm + cur = build_norm(cur, model.layers[il].attn_out_norm, nullptr, LLM_NORM, il); + + ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, + NULL, NULL, NULL, NULL, NULL, + model.layers[il].ffn_down, + NULL, NULL, NULL, + LLM_FFN_GEGLU, LLM_FFN_SEQ, il); + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM, -1); + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + + struct llm_build_neo_bert : public llm_graph_context { llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -19015,6 +19191,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MODERN_BERT: + { + llm = std::make_unique>(*this, params); + } break; case LLM_ARCH_NEO_BERT: { llm = std::make_unique(*this, params); @@ -19514,6 +19694,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DBRX: case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V3: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: diff --git a/src/llama-model.h b/src/llama-model.h index 10b1767f27228..c544e4f88b1d7 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -23,6 +23,7 @@ enum llm_type { LLM_TYPE_17M, LLM_TYPE_22M, LLM_TYPE_33M, + LLM_TYPE_47M, LLM_TYPE_60M, LLM_TYPE_70M, LLM_TYPE_80M, diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 8cb36661a0c96..39fa446f56dd6 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1864,7 +1864,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || - tokenizer_pre == "mellum") { + tokenizer_pre == "mellum" || + tokenizer_pre == "modern-bert" ) { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "jina-v1-en" || @@ -2499,6 +2500,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { for (const auto * token : {"", "", "<|endoftext|>"}) { _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false); } + } else if ( _contains_any(model_name, {"modern-bert"})) { + if ( token_to_id.count("MASK") == 0 ) { + LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__); + } + else { + _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true); + } } } }