Fixed f_norm_rms_eps bug
arki05 committed Mar 21, 2024
commit 6052e3b3a7616f6cc506771bfcf93bb21a019ad2
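Note on the lookup pattern: the converter changes below use self.hparams.get(...) for a single config key, and self.find_hparam([...], optional=True) to try several alternative key names used across HF configs. A minimal sketch of the assumed find_hparam semantics (written as a free function for illustration; in convert-hf-to-gguf.py it is a method on Model):

def find_hparam(hparams: dict, keys: list[str], optional: bool = False):
    # return the value for the first candidate key present in the config
    for key in keys:
        if key in hparams:
            return hparams[key]
    # optional lookups fall back to None; required ones fail loudly
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")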
21 changes: 21 additions & 0 deletions convert-hf-to-gguf.py
@@ -93,31 +93,42 @@ def set_gguf_parameters(self):

        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
            self.gguf_writer.add_context_length(n_ctx)
            print(f"gguf: context length = {n_ctx}")

        n_embd = self.find_hparam(["hidden_size", "n_embd"])
        self.gguf_writer.add_embedding_length(n_embd)
        print(f"gguf: embedding length = {n_embd}")

        if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
            self.gguf_writer.add_feed_forward_length(n_ff)
            print(f"gguf: feed forward length = {n_ff}")

        n_head = self.find_hparam(["num_attention_heads", "n_head"])
        self.gguf_writer.add_head_count(n_head)
        print(f"gguf: head count = {n_head}")

        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
            self.gguf_writer.add_head_count_kv(n_head_kv)
            print(f"gguf: key-value head count = {n_head_kv}")

        if (rope_theta := self.hparams.get("rope_theta")) is not None:
            self.gguf_writer.add_rope_freq_base(rope_theta)
            print(f"gguf: rope theta = {rope_theta}")
        if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
            self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
            print(f"gguf: rms norm epsilon = {f_rms_eps}")
        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
            self.gguf_writer.add_layer_norm_eps(f_norm_eps)
            print(f"gguf: layer norm epsilon = {f_norm_eps}")
        if (n_experts := self.hparams.get("num_local_experts")) is not None:
            self.gguf_writer.add_expert_count(n_experts)
            print(f"gguf: expert count = {n_experts}")
        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
            self.gguf_writer.add_expert_used_count(n_experts_used)
            print(f"gguf: experts used count = {n_experts_used}")

        self.gguf_writer.add_file_type(self.ftype)
        print(f"gguf: file type = {self.ftype}")

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
@@ -1057,6 +1068,16 @@ class GrokModel(Model):

    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_name("Grok")




@Model.register("MiniCPMForCausalLM")
class MiniCPMModel(Model):
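Since the commit title says the f_norm_rms_eps handling was buggy, one way to spot-check a converted file is to read its metadata back with the gguf-py package from this repo. A minimal sketch, assuming a hypothetical output path (GGUFReader and its fields mapping exist in gguf-py; the scalar decoding below is one reasonable reading of ReaderField):

from gguf import GGUFReader

reader = GGUFReader("grok-1-f16.gguf")  # output path is an assumption
for name, field in reader.fields.items():
    if "epsilon" in name:
        # for scalar fields, data[0] indexes the part holding the value
        print(name, field.parts[field.data[0]][0])

With the fix applied, the listing should include grok.attention.layer_norm_rms_epsilon carrying the model's rms_norm_eps.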
11 changes: 11 additions & 0 deletions llama.cpp
@@ -1645,6 +1645,7 @@ enum e_model {
    MODEL_40B,
    MODEL_65B,
    MODEL_70B,
    MODEL_314B,
    MODEL_SMALL,
    MODEL_MEDIUM,
    MODEL_LARGE,
@@ -3314,6 +3315,7 @@ static const char * llama_model_type_name(e_model type) {
        case MODEL_40B:    return "40B";
        case MODEL_65B:    return "65B";
        case MODEL_70B:    return "70B";
        case MODEL_314B:   return "314B";
        case MODEL_SMALL:  return "0.1B";
        case MODEL_MEDIUM: return "0.4B";
        case MODEL_LARGE:  return "0.8B";
@@ -3452,6 +3454,15 @@ static void llm_load_hparams(
                    default: model.type = e_model::MODEL_UNKNOWN;
                }
            } break;
        case LLM_ARCH_GROK:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                switch (hparams.n_layer) {
                    case 64: model.type = e_model::MODEL_314B; break;
                    default: model.type = e_model::MODEL_UNKNOWN;
                }
            } break;
        case LLM_ARCH_FALCON:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
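On the loader side, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS refers to the same GGUF key the converter writes via add_layer_norm_rms_eps. gguf-py exposes the key templates, which makes the converter/loader pairing easy to verify; a short sketch (constant names follow gguf-py's constants module, and "grok" is the architecture string added by this PR):

from gguf.constants import Keys

arch = "grok"
print(Keys.Attention.LAYERNORM_EPS.format(arch=arch))      # grok.attention.layer_norm_epsilon
print(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=arch))  # grok.attention.layer_norm_rms_epsilon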