Merged
Changes from 1 commit
Commits (73)
874c341
support splits in convert.py
christianazinn Apr 27, 2024
72cbd4e
Support split by size and dry run to write estimated shards/filesizes
christianazinn Apr 28, 2024
702a744
Move split functionality to new GGUFManager class
christianazinn Apr 28, 2024
c33bdf3
fix improper function signature
christianazinn Apr 29, 2024
b7c6120
tentative push of convert-hf-to-gguf support
christianazinn May 5, 2024
14b3291
Merge branch 'master' into convert-split
mofosyne May 9, 2024
87a98a5
resolve merge + SplitArguments for easier parsing
christianazinn May 10, 2024
2dd7841
Merge remote-tracking branch 'origin' into convert-split
christianazinn May 23, 2024
3ff27ef
Fix eager tensor memory leak and remove convert.py changes
christianazinn May 23, 2024
6b5c375
refactor SplitStrategy to be a deque
christianazinn May 24, 2024
09baf2f
fix Q8 quantization
christianazinn Jun 3, 2024
240243e
remove unnecessary imports in gguf_manager
christianazinn Jun 3, 2024
140eb52
Merge branch 'master' into convert-split
christianazinn Jun 3, 2024
a9c7703
fix final? merge issue
christianazinn Jun 3, 2024
efead04
fix gguf_writer placement and remove comments
christianazinn Jun 3, 2024
c8ecbc6
oops, actually fix gguf_writer placement
christianazinn Jun 3, 2024
3e9430d
reduce duplicated code from gguf_writer
christianazinn Jun 5, 2024
f6fd3ea
further simplify GGUFManager
christianazinn Jun 5, 2024
bb5ee02
simplify even further and standardize with GGUFWriter
christianazinn Jun 5, 2024
5ad397d
reduce diffs with master
christianazinn Jun 5, 2024
ce7e698
form shards while adding tensors, SHA256 sums agree with master
christianazinn Jun 5, 2024
706bd69
re-add type hint
christianazinn Jun 6, 2024
6a05183
GGUFWriter compatibility fix
christianazinn Jun 6, 2024
3328b0a
Shard dataclass and un-negative dont_add_architecture
christianazinn Jun 6, 2024
1cbab22
type consistency in format_n_bytes_to_str
christianazinn Jun 6, 2024
2037eab
move kv keys to constants.py
christianazinn Jun 6, 2024
83e4a3f
make pathlib explicit
christianazinn Jun 6, 2024
13ffe22
base-1024 bytes to base-1000
christianazinn Jun 6, 2024
6d3a256
rename GGUFManager to GGUFWriterSplit
christianazinn Jun 7, 2024
1312e28
Update gguf-py/gguf/constants.py
christianazinn Jun 7, 2024
5f29d4a
fix convert-hf-to-gguf.py permissions
christianazinn Jun 7, 2024
0283fc1
fix line endings
christianazinn Jun 7, 2024
dc5cf5f
Update gguf-py/gguf/gguf_writer_split.py
christianazinn Jun 7, 2024
e093dfb
convert-hf : restore executable file permission
compilade Jun 7, 2024
9576965
examples/convert-legacy-llama.py: restore executable file permission
christianazinn Jun 8, 2024
c6ae1d6
reinstate original gguf package import and fix type annotation
christianazinn Jun 8, 2024
2e70fa1
attempt to appease the linter
christianazinn Jun 8, 2024
891b19c
attempt 2 to appease the linter
christianazinn Jun 8, 2024
02be0dd
attempt 3 to appease the linter
christianazinn Jun 8, 2024
f658e91
comma consistency
christianazinn Jun 8, 2024
079dfe3
Update convert-hf-to-gguf.py
christianazinn Jun 8, 2024
282e71f
edit cmd line args
christianazinn Jun 9, 2024
666bb09
Merge branch 'master' into convert-split
christianazinn Jun 9, 2024
03cc9bc
use simplification from #7827
christianazinn Jun 9, 2024
97dd416
kv/ti data are still wrong
christianazinn Jun 9, 2024
ff2dd7d
try to refactor kv data (still fails)
christianazinn Jun 9, 2024
ba1be97
fix ti data messiness
christianazinn Jun 9, 2024
69d6e7a
Merge branch 'master' into convert-split
christianazinn Jun 9, 2024
0779f2f
tidy up
christianazinn Jun 9, 2024
a234bf8
fix linting
christianazinn Jun 9, 2024
49b9fbe
actually make the linter happy
christianazinn Jun 9, 2024
0471f67
cleanup round 1
christianazinn Jun 9, 2024
5a96b8f
remove SplitStrategy, SplitArguments
christianazinn Jun 9, 2024
f7ecd99
appease linter
christianazinn Jun 9, 2024
9d7f694
fix typing and clean up
christianazinn Jun 9, 2024
0417104
fix linting
christianazinn Jun 9, 2024
70a6bc9
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 9, 2024
1e2d9cb
progress bar, fix split logic
christianazinn Jun 9, 2024
f7e7983
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 10, 2024
79bd2bf
catch oversights
christianazinn Jun 10, 2024
7eea552
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 10, 2024
99f9a24
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 10, 2024
ad02c94
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 10, 2024
c1b1a29
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 10, 2024
4550826
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 10, 2024
efa0609
swap bar orders
christianazinn Jun 10, 2024
b843445
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 10, 2024
854bd64
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 10, 2024
05b183f
compatibility fix
christianazinn Jun 10, 2024
e9895d2
Update gguf-py/gguf/gguf_writer.py
christianazinn Jun 10, 2024
4e4e376
Merge branch 'master' into convert-split
christianazinn Jun 15, 2024
163712e
Update convert-hf-to-gguf.py
mofosyne Jun 23, 2024
6e4182c
Merge branch 'master' into convert-split
christianazinn Jun 24, 2024
reduce duplicated code from gguf_writer
christianazinn committed Jun 5, 2024
commit 3e9430df33c1c0f63087365b10aaa2284e1d4b5a
310 changes: 24 additions & 286 deletions gguf-py/gguf/gguf_manager.py
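
For orientation before the diff: after this commit, GGUFManager is meant to act as a near drop-in replacement for GGUFWriter, buffering metadata and tensors until write_to_file() forms the shards. A minimal usage sketch follows, assuming only the names visible in the diff below; the Namespace fields, file path, and tensor are hypothetical, and this intermediate commit may differ from the API that was eventually merged:

    import numpy as np
    from argparse import Namespace
    from gguf.gguf_manager import GGUFManager, SplitArguments

    # Hypothetical CLI options, mirroring the SplitArguments attributes in the diff.
    args = Namespace(split=True, split_max_tensors=128, split_max_size=0,
                     dry_run=False, small_first_shard=False)

    manager = GGUFManager("model.gguf", "llama", SplitArguments(args))
    manager.add_string("general.name", "example")  # buffered in kv_data, not written yet
    manager.add_tensor("token_embd.weight", np.zeros((4, 4), dtype=np.float32))
    manager.write_to_file()  # shards are laid out and written here
    manager.close()
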
@@ -16,11 +16,7 @@
from .constants import (
GGMLQuantizationType,
GGUFEndian,
GGUFValueType,
Keys,
RopeScalingType,
PoolingType,
TokenType,
GGUFValueType
)
from .gguf_writer import GGUFWriter

@@ -33,7 +29,7 @@

SplitTensorsPerFile: TypeAlias = deque[tuple[os.PathLike[str], deque[tuple[str, Any]], GGUFWriter]] # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)]
KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType]] # {key: (value, type)}
TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationType] # (tensor name, tensor data, tensor dtype), aka LazyModel
TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationType] # (tensor name, tensor data, tensor dtype)


class SplitStyle(IntEnum):
@@ -43,13 +39,6 @@ class SplitStyle(IntEnum):


class SplitArguments:
split: bool
dry_run: bool
small_first_shard: bool
split_max_tensors: int
split_max_size: int
split_style: SplitStyle

def __init__(self, args: Namespace = None) -> None:
self.split = args.split if args else False
self.split_max_tensors = args.split_max_tensors if args else 0
@@ -107,7 +96,7 @@ def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arc

for i, shard in enumerate(shards):
outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + shard_offset, total_shards))
self.append((outname, deque(shard), GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
self.append((outname, shard, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))

@staticmethod
def get_tensor_size(tensor) -> int:
@@ -146,35 +135,34 @@ def format_n_bytes_to_str(num: int) -> str:
num /= 1024.0
return f"{num:.1f}T - over 1TB, --split recommended"


# ideally this has most of the same signatures as GGUFWriter so it's nearly a drop-in replacement
class GGUFManager:
# TODO fall back to normal GGUFWriter in convert-hf-to-gguf.py if no --split
class GGUFManager(GGUFWriter):
kv_data: KVTempData
tensors: deque[TensorTempData]
tensors: list[TensorTempData]
split_arguments: SplitArguments
split_strategy: SplitStrategy
dtype: GGMLQuantizationType

def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
) -> None:
# TODO be able to use superclass constructor
# super().__init__(path, arch, use_temp_file=use_temp_file, endianess=endianess)
self.arch = arch
self.path = path
self.endianess = endianess
self.offset_tensor = 0
self.kv_data = {}
self.tensors = deque()
self.tensors = []
# TODO how many of these do you need
self.split_strategy = None
self.total_shards = None
self.total_tensors = None
self.use_temp_file = use_temp_file
self.split_arguments = split_arguments

self.recent_key = None
self.add_architecture()

# have to consolidate because we need to know kv data count and tensor count before we can write the header
# and we need to write tensor info before we can write metadata
# these all kinda show up around the same places anyway so it's not a huge deal?
# TODO split back into write_header_to_file, write_kv_data_to_file, write_ti_data_to_file
def write_to_file(self, meta_only: bool = False) -> None:

# here is the first place you can assume you have all tensors written and you can establish the size of the file - so logic goes here
@@ -232,11 +220,12 @@ def write_to_file(self, meta_only: bool = False) -> None:
while True:
try:
(_, tensors, writer) = self.split_strategy.popleft()
tensors = deque(tensors) if tensors else None
except IndexError:
break

shard_num_tensors = len(tensors) if tensors else 0

if tensors:
while True:
try:
@@ -254,44 +243,16 @@ def write_to_file(self, meta_only: bool = False) -> None:
ct = ct + 1
del tensors

def add_uint8(self, key: str, val: int) -> None:
self.kv_data[key] = (val, GGUFValueType.UINT8)

def add_int8(self, key: str, val: int) -> None:
self.kv_data[key] = (val, GGUFValueType.INT8)

def add_uint16(self, key: str, val: int) -> None:
self.kv_data[key] = (val, GGUFValueType.UINT16)

def add_int16(self, key: str, val: int) -> None:
self.kv_data[key] = (val, GGUFValueType.INT16)

def add_uint32(self, key: str, val: int) -> None:
self.kv_data[key] = (val, GGUFValueType.UINT32)

def add_int32(self, key: str, val: int) -> None:
self.kv_data[key] = (val, GGUFValueType.INT32)

def add_float32(self, key: str, val: float) -> None:
self.kv_data[key] = (val, GGUFValueType.FLOAT32)

def add_uint64(self, key: str, val: int) -> None:
self.kv_data[key] = (val, GGUFValueType.UINT64)

def add_int64(self, key: str, val: int) -> None:
self.kv_data[key] = (val, GGUFValueType.INT64)

def add_float64(self, key: str, val: float) -> None:
self.kv_data[key] = (val, GGUFValueType.FLOAT64)

def add_bool(self, key: str, val: bool) -> None:
self.kv_data[key] = (val, GGUFValueType.BOOL)

def add_string(self, key: str, val: str) -> None:
if not val:
return
self.kv_data[key] = (val, GGUFValueType.STRING)
# override add_key, add_val to handle kv data separately
def add_key(self, key: str) -> None:
self.recent_key = key

def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
if self.recent_key is None:
raise ValueError("No key set for value")
self.kv_data[self.recent_key] = (val, vtype)

# need to handle arrays separately
def add_array(self, key: str, val: Sequence[Any]) -> None:
if not isinstance(val, Sequence):
raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
@@ -303,231 +264,8 @@ def add_tensor(
) -> None:
if self.endianess == GGUFEndian.BIG:
tensor.byteswap(inplace=True)

# TODO reimplement temp file
# I'm pretty sure it gets handled per shard?

self.tensors.append((name, tensor, raw_dtype))

def close(self) -> None:
for _, _, writer in self.split_strategy:
writer.close()

def add_architecture(self) -> None:
self.add_string(Keys.General.ARCHITECTURE, self.arch)

def add_author(self, author: str) -> None:
self.add_string(Keys.General.AUTHOR, author)

def add_version(self, version: str) -> None:
self.add_string(Keys.General.VERSION, version)

def add_tensor_data_layout(self, layout: str) -> None:
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)

def add_url(self, url: str) -> None:
self.add_string(Keys.General.URL, url)

def add_description(self, description: str) -> None:
self.add_string(Keys.General.DESCRIPTION, description)

def add_licence(self, licence: str) -> None:
self.add_string(Keys.General.LICENSE, licence)

def add_source_url(self, url: str) -> None:
self.add_string(Keys.General.SOURCE_URL, url)

def add_source_hf_repo(self, repo: str) -> None:
self.add_string(Keys.General.SOURCE_HF_REPO, repo)

def add_file_type(self, ftype: int) -> None:
self.add_uint32(Keys.General.FILE_TYPE, ftype)

def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)

def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)

def add_custom_alignment(self, alignment: int) -> None:
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)

def add_vocab_size(self, size: int) -> None:
self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)

def add_context_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)

def add_embedding_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)

def add_block_count(self, length: int) -> None:
self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)

def add_feed_forward_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)

def add_parallel_residual(self, use: bool) -> None:
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

def add_head_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)

def add_head_count_kv(self, count: int) -> None:
self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)

def add_key_length(self, length: int) -> None:
self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)

def add_value_length(self, length: int) -> None:
self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)

def add_max_alibi_bias(self, bias: float) -> None:
self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)

def add_clamp_kqv(self, value: float) -> None:
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)

def add_logit_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

def add_expert_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)

def add_expert_used_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)

def add_layer_norm_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

def add_layer_norm_rms_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)

def add_causal_attention(self, value: bool) -> None:
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)

def add_rope_freq_base(self, value: float) -> None:
self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)

def add_rope_scaling_type(self, value: RopeScalingType) -> None:
self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)

def add_rope_scaling_factor(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)

def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)

def add_rope_scaling_finetuned(self, value: bool) -> None:
self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)

def add_ssm_conv_kernel(self, value: int) -> None:
self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

def add_ssm_inner_size(self, value: int) -> None:
self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)

def add_ssm_state_size(self, value: int) -> None:
self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)

def add_ssm_time_step_rank(self, value: int) -> None:
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)

def add_tokenizer_pre(self, pre: str) -> None:
self.add_string(Keys.Tokenizer.PRE, pre)

def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
self.add_array(Keys.Tokenizer.LIST, tokens)

def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
self.add_array(Keys.Tokenizer.MERGES, merges)

def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)

def add_token_type_count(self, value: int) -> None:
self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)

def add_token_scores(self, scores: Sequence[float]) -> None:
self.add_array(Keys.Tokenizer.SCORES, scores)

def add_bos_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.BOS_ID, id)

def add_eos_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.EOS_ID, id)

def add_unk_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.UNK_ID, id)

def add_sep_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.SEP_ID, id)

def add_pad_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.PAD_ID, id)

def add_cls_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.CLS_ID, id)

def add_mask_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.MASK_ID, id)

def add_add_bos_token(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_BOS, value)

def add_add_eos_token(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_EOS, value)

def add_add_space_prefix(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)

def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
if isinstance(value, list):
template_default = None
template_names = set()

for choice in value:
name = choice.get('name', '')
template = choice.get('template')

# Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it
name = ''.join((c if c in ascii_letters + digits else '_' for c in name))

if name and template is not None:
if name == 'default':
template_default = template
else:
template_names.add(name)
self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)

if template_names:
self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names))

if template_default is None:
return

value = template_default

self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)

def add_prefix_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.PREFIX_ID, id)

def add_suffix_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id)

def add_middle_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id)

def add_eot_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.EOT_ID, id)
writer.close()
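
The add_key/add_val override above is what lets this commit delete the long run of typed add_* wrappers: GGUFWriter's inherited setters (add_uint32, add_string, and so on) all funnel through those two primitives, so overriding just the primitives redirects every key-value write into the kv_data buffer until write_to_file replays it per shard. A self-contained sketch of that pattern, illustrative only and not the library code:

    from typing import Any

    class Writer:
        # Stand-in for GGUFWriter: typed setters funnel through two primitives.
        def add_key(self, key: str) -> None:
            raise NotImplementedError
        def add_val(self, val: Any, vtype: Any = None, add_vtype: bool = True) -> None:
            raise NotImplementedError
        def add_uint32(self, key: str, val: int) -> None:
            self.add_key(key)
            self.add_val(val, "UINT32")

    class BufferingWriter(Writer):
        # Stand-in for GGUFManager: overrides the primitives to defer writes.
        def __init__(self) -> None:
            self.kv_data: dict[str, tuple[Any, Any]] = {}
            self.recent_key: str | None = None
        def add_key(self, key: str) -> None:
            self.recent_key = key  # remember the key until its value arrives
        def add_val(self, val: Any, vtype: Any = None, add_vtype: bool = True) -> None:
            if self.recent_key is None:
                raise ValueError("No key set for value")
            self.kv_data[self.recent_key] = (val, vtype)  # buffer instead of writing

    w = BufferingWriter()
    w.add_uint32("general.file_type", 1)  # inherited setter, deferred storage
    assert w.kv_data == {"general.file_type": (1, "UINT32")}

One subclass-level override replaces roughly 260 lines of duplicated setters while keeping the public surface identical, which is exactly the "24 additions & 286 deletions" this commit reports.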