Commits (32)
6151592
constants and tensor mappings for modern bert support, model not supp…
ryan-mangeno Aug 21, 2025
6643c5a
conversion now working, hf -> gguf
ryan-mangeno Aug 21, 2025
ac67fc6
working on support, now working on building graph
ryan-mangeno Aug 25, 2025
cc40378
some cleanup
ryan-mangeno Aug 25, 2025
41b6864
cleanup
ryan-mangeno Aug 26, 2025
cc3d7ab
continuing
ryan-mangeno Aug 26, 2025
4ceb828
correct tensor shape for qkv
ryan-mangeno Aug 26, 2025
18c0c23
fixed tensor mappings and working on building graph
ryan-mangeno Aug 27, 2025
bffe3c9
tensor debugging now works -> (llama-eval-callback), instead of simul…
ryan-mangeno Aug 28, 2025
8f32843
cleanup
ryan-mangeno Aug 28, 2025
9805635
cleanup
ryan-mangeno Aug 28, 2025
40249dd
cleanup
ryan-mangeno Aug 28, 2025
853f344
more cleanup
ryan-mangeno Aug 28, 2025
2a1c750
ubatch issues, the assert for checking equal seqs in llama-graph.cpp …
ryan-mangeno Aug 28, 2025
c73eb68
added cls token per previous modern bert attempt, still working on ch…
ryan-mangeno Aug 29, 2025
ca353d3
fixed pre tokenizer and still working through previous pr
ryan-mangeno Sep 2, 2025
6d86944
working through previous attempt, implemented more accurate conversion…
ryan-mangeno Sep 3, 2025
39c0291
fixed pre tokenizer
ryan-mangeno Sep 3, 2025
e101005
working on swa with local and global alternating attention
ryan-mangeno Sep 8, 2025
044bc7d
some cleanup and now fails on build attn
ryan-mangeno Sep 8, 2025
e296a0b
starting to work, and some cleanup, currently failing on last layer c…
ryan-mangeno Sep 8, 2025
2bacfb0
alternating rope implemented and modern bert graph build succeeds
ryan-mangeno Sep 11, 2025
4e7c879
fixed assert for equal ubatch seq
ryan-mangeno Sep 11, 2025
20d448a
cleanup
ryan-mangeno Sep 11, 2025
db4f565
added mask check in vocab
ryan-mangeno Sep 12, 2025
da0604a
fixed alternating rope, the hparams.rope_freq_base_train and hparams.…
ryan-mangeno Sep 12, 2025
43a2980
reuse variable
ryan-mangeno Sep 13, 2025
e368442
fixed merge conflicts and added print debug check for swa type
ryan-mangeno Sep 13, 2025
7036cc8
removed repeat
ryan-mangeno Sep 13, 2025
2522ce8
merge fixes
ryan-mangeno Sep 14, 2025
e043815
Merge branch 'master' into modern-bert-support
ryan-mangeno Sep 15, 2025
35667f2
Merge branch 'master' into modern-bert-support
ryan-mangeno Sep 17, 2025
32 changes: 32 additions & 0 deletions convert_hf_to_gguf.py
@@ -888,6 +888,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152":
# ref: https://huggingface.co/ibm-granite/granite-embedding-small-english-r2
res = "modern-bert"
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
res = "llada-moe"
@@ -8733,6 +8736,35 @@ def prepare_tensors(self):
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification")
class ModernBertModel(BertModel):
model_arch = gguf.MODEL_ARCH.MODERN_BERT

def set_vocab(self):
self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
self.gguf_writer.add_rope_freq_base(self.hparams["global_rope_theta"])
self.gguf_writer.add_rope_freq_base_swa(self.hparams["local_rope_theta"])
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# These layers act as MLM head, so we don't need them
if name.startswith("decoder."):
return []

if name.startswith("model."):
name = name[6:]

return super().modify_tensors(data_torch, name, bid)



class MistralModel(LlamaModel):
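To make the prefix handling in ModernBertModel.modify_tensors concrete, here is a self-contained sketch of the same logic; the tensor names are illustrative examples, not an exhaustive list from the checkpoint:

# mirrors the logic added above: drop MLM decoder tensors, strip the "model."
# prefix, then BertModel.modify_tensors applies the usual name mapping
def preprocess_name(name: str) -> str | None:
    if name.startswith("decoder."):   # MLM head, not needed for embeddings
        return None
    if name.startswith("model."):     # HF checkpoint prefix on encoder weights
        return name[len("model."):]
    return name

for raw in ("decoder.weight",
            "model.embeddings.tok_embeddings.weight",
            "model.layers.0.attn.Wqkv.weight"):
    print(raw, "->", preprocess_name(raw))
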
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-small-english-r2", },
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
]

17 changes: 17 additions & 0 deletions gguf-py/gguf/constants.py
@@ -158,6 +158,7 @@ class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
@@ -322,6 +323,7 @@ class MODEL_ARCH(IntEnum):
STARCODER = auto()
REFACT = auto()
BERT = auto()
MODERN_BERT = auto()
NOMIC_BERT = auto()
NOMIC_BERT_MOE = auto()
NEO_BERT = auto()
@@ -658,6 +660,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.STARCODER: "starcoder",
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.MODERN_BERT: "modern-bert",
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
MODEL_ARCH.NEO_BERT: "neo-bert",
@@ -1194,6 +1197,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.CLS,
MODEL_TENSOR.CLS_OUT,
],
MODEL_ARCH.MODERN_BERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.CLS,
MODEL_TENSOR.CLS_OUT,
],
MODEL_ARCH.NOMIC_BERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
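As a quick sanity check, the new FREQ_BASE_SWA format string expands to the following metadata key for this architecture, matching the "%s.rope.freq_base_swa" entry added on the C++ side further down:

FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
print(FREQ_BASE_SWA.format(arch="modern-bert"))  # -> modern-bert.rope.freq_base_swa
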
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -820,6 +820,9 @@ def add_iclr_lora_rank(self, length: int) -> None:
def add_value_residual_mix_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)

def add_rope_freq_base_swa(self, value: float) -> None:
self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)

def add_gate_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)

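A hedged usage sketch of the new writer helper next to the existing one; the theta values are placeholders chosen for illustration, not values read from any config:

import gguf

# minimal sketch, assuming the gguf-py package from this branch is importable
writer = gguf.GGUFWriter("modern-bert.gguf", "modern-bert")
writer.add_rope_freq_base(160000.0)      # global-attention rope theta (example value)
writer.add_rope_freq_base_swa(10000.0)   # sliding-window (local) rope theta (example value)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
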
11 changes: 10 additions & 1 deletion gguf-py/gguf/tensor_mapping.py
@@ -17,6 +17,7 @@ class TensorNameMap:
"embed_tokens", # embeddinggemma
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"embeddings.tok_embeddings", # modern bert
"language_model.embedding.word_embeddings", # persimmon
"wte", # gpt2
"transformer.embd.wte", # phi2
@@ -46,6 +47,7 @@ class TensorNameMap:
MODEL_TENSOR.TOKEN_EMBD_NORM: (
"word_embeddings_layernorm", # bloom
"embeddings.LayerNorm", # bert
"embeddings.norm", # modern bert
"emb_ln", # nomic-bert
"transformer.norm", # openelm
"rwkv.blocks.0.pre_ln", # rwkv
@@ -99,6 +101,7 @@ class TensorNameMap:
"backbone.final_layer_norm", # wavtokenizer
"model.norm", # llama4
"model.transformer.ln_f", # llada
"final_norm", # modern bert
),

# Rope frequencies
@@ -145,9 +148,10 @@ class TensorNameMap:
"model.layers.{bid}.input_layernorm", # llama4
"layers.{bid}.input_layernorm", # embeddinggemma
"transformer_encoder.{bid}.attention_norm", # neobert
"layers.{bid}.attn_norm", # bert
"model.layers.{bid}.operator_norm", # lfm2
"model.transformer.blocks.{bid}.attn_norm", # llada
"layers.{bid}.input_layernorm", # qwen3-embedding
"layers.{bid}.input_layernorm", # qwen3-embedding,
),

# Attention norm 2
@@ -177,6 +181,7 @@ class TensorNameMap:
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
"transformer_encoder.{bid}.qkv", # neobert
"layers.{bid}.attn.Wqkv", # modern bert
),

# Attention query
@@ -250,6 +255,7 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.linear_attn", # deci
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"layers.{bid}.attn.Wo", # modern bert
"transformer.layer.{bid}.attention.out_lin", # distillbert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
@@ -325,6 +331,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.pre_mlp_norm", # plamo2
"model.transformer.blocks.{bid}.ff_norm", # llada
"layers.{bid}.post_attention_layernorm", # qwen3-embedding
"layers.{bid}.mlp_norm" # modern bert
),

# Post feed-forward norm
@@ -378,6 +385,7 @@ class TensorNameMap:
"layers.{bid}.mlp.up_proj", # embeddinggemma
"layers.{bid}.feed_forward.w3", # llama-pth
"encoder.layer.{bid}.intermediate.dense", # bert
"layers.{bid}.mlp.Wi", # modern bert
"transformer.layer.{bid}.ffn.lin1", # distillbert
"transformer.h.{bid}.mlp.fc_in", # gpt-j
"transformer.h.{bid}.mlp.linear_3", # refact
@@ -479,6 +487,7 @@ class TensorNameMap:
"layers.{bid}.mlp.down_proj", # embeddinggemma
"layers.{bid}.feed_forward.w2", # llama-pth
"encoder.layer.{bid}.output.dense", # bert
"layers.{bid}.mlp.Wo", # modern bert
"transformer.layer.{bid}.ffn.lin2", # distillbert
"transformer.h.{bid}.mlp.fc_out", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
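The new entries can be exercised directly through gguf-py's tensor-name map; a short sketch against this branch (the block count is an arbitrary example):

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MODERN_BERT, 22)
for hf_name in ("layers.0.attn.Wqkv.weight",
                "layers.0.mlp.Wi.weight",
                "final_norm.weight"):
    print(hf_name, "->", tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))
# expected with this branch: blk.0.attn_qkv.weight, blk.0.ffn_up.weight, output_norm.weight
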
20 changes: 20 additions & 0 deletions src/llama-arch.cpp
@@ -18,6 +18,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_STARCODER, "starcoder" },
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_MODERN_BERT, "modern-bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_NEO_BERT, "neo-bert" },
@@ -176,9 +177,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },


{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
@@ -525,6 +528,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_CLS_OUT, "cls.output" },
},
},
{
LLM_ARCH_MODERN_BERT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_CLS, "cls" },
{ LLM_TENSOR_CLS_OUT, "cls.output" },
},
},
{
LLM_ARCH_NOMIC_BERT,
{
2 changes: 2 additions & 0 deletions src/llama-arch.h
@@ -22,6 +22,7 @@ enum llm_arch {
LLM_ARCH_STARCODER,
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_MODERN_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_NEO_BERT,
@@ -184,6 +185,7 @@ enum llm_kv {
LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_FREQ_BASE_SWA,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
LLM_KV_ROPE_SCALING_ATTN_FACTOR,
4 changes: 2 additions & 2 deletions src/llama-graph.cpp
@@ -263,7 +263,8 @@ static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_s
const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
(swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
(swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
(swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" :
(swa_type == LLAMA_SWA_TYPE_LOCAL) ? "LLAMA_SWA_TYPE_LOCAL" : "unknown";
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -1578,7 +1579,6 @@ ggml_tensor * llm_graph_context::build_attn(
// optionally store to KV cache
if (k_cur) {
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();

ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
}

1 change: 1 addition & 0 deletions src/llama-hparams.h
@@ -20,6 +20,7 @@ enum llama_swa_type {
LLAMA_SWA_TYPE_STANDARD = 1,
LLAMA_SWA_TYPE_CHUNKED = 2,
LLAMA_SWA_TYPE_SYMMETRIC = 3,
LLAMA_SWA_TYPE_LOCAL = 4,
};

struct llama_hparams_posnet {
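LLAMA_SWA_TYPE_LOCAL exists to describe ModernBERT's alternating attention, where most layers use a sliding window and every n-th layer attends globally, each with its own rope theta. A small illustrative schedule; the layer count, every-3rd-layer cadence, and 128-token window are assumptions based on the upstream ModernBERT config, not values taken from this diff:

# hypothetical layer schedule for an alternating local/global encoder
n_layer = 22         # example depth (assumed)
global_every = 3     # ModernBERT-style global_attn_every_n_layers (assumed)
window = 128         # "local_attention" sliding-window width in tokens (assumed)

for il in range(n_layer):
    if il % global_every == 0:
        print(f"layer {il:2d}: global attention, freq_base = rope_freq_base")
    else:
        print(f"layer {il:2d}: local attention (window {window}), freq_base = rope_freq_base_swa")
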