Shared emb_tokens/lm_head on nibbled 4bit qweights #1854
base: main
Conversation
…im for reshape of packed weights
@microsoft-github-policy-service agree company="Microsoft"
# Allow extra_options to override use_packed_matmul
if "unpack_matmul" in extra_options:
This is an optimization opportunity that should be auto-detected by the model builder; we should not need to give this responsibility to the user. You can see the review comments on this PR for more details.
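For illustration of that point, here is a minimal, purely hypothetical sketch of what auto-detection could look like; the function name, the example shapes, and the packing condition are assumptions for this example, not builder.py code:

```python
# Hypothetical sketch only: derive the packed-QKV decision from information the
# builder already has, instead of a user-facing "unpack_matmul" extra_option.
def can_pack_qkv(q_shape, k_shape, v_shape) -> bool:
    # Packing fuses Q/K/V into one MatMul by concatenating the weights along the
    # output dimension, so all three must share the same input (hidden) dimension.
    return q_shape[0] == k_shape[0] == v_shape[0]

# Example with GQA-style shapes (hidden=2048, fewer KV heads):
use_packed_matmul = can_pack_qkv((2048, 2048), (2048, 512), (2048, 512))  # -> True
```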
elif quant_method in {"k_quant_mixed", "k_quant_last"}:
elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
    from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig
Let's move this import up. It was previously here because it was not part of a stable release.
onnxruntime-genai/src/python/py/models/builder.py
Lines 24 to 28 in d4eabac
from onnxruntime.quantization.matmul_nbits_quantizer import (
    MatMulNBitsQuantizer,
    QuantFormat,
    RTNWeightOnlyQuantConfig,
)
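A minimal sketch of the consolidated import the comment asks for, i.e. moving `KQuantWeightOnlyQuantConfig` up into this top-level block (assuming it is now available in the stable onnxruntime release being targeted):

```python
from onnxruntime.quantization.matmul_nbits_quantizer import (
    KQuantWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
    QuantFormat,
    RTNWeightOnlyQuantConfig,
)
```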
if quant_method == "rtn":
    int4_algo_config = RTNWeightOnlyQuantConfig()
if quant_method in {"rtn", "rtn_last"}:
I think this can be simplified to the following.
if quant_method in {"rtn", "rtn_last"}:
    if quant_method == "rtn_last":
        customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
    int4_algo_config = RTNWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
elif quant_method in {"k_quant_mixed", "k_quant_last"}:
elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
I think this can be simplified to the following.
elif quant_method in {"k_quant", "k_quant_mixed", "k_quant_last"}:
    if quant_method != "k_quant":
        customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
    if quant_method == "k_quant_mixed":
        # k_quant_mixed is from llama.cpp.
        # Reference: https://github.com/ggml-org/llama.cpp/blob/36667c8edcded08063ed51c7d57e9e086bbfc903/src/llama-quant.cpp#L136
        # We also consider that some MatMuls are more sensitive to quantization than other MatMuls.
        layers_to_exclude = [
            i
            for i in range(self.num_layers)
            if i < self.num_layers / 8 or i >= 7 * self.num_layers / 8 or (i - round(self.num_layers / 8)) % 3 == 2
        ]
        for i in layers_to_exclude:
            customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8}
            customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
            customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}
    int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)

self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"}
if not self.int8_lm_head:
self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"}
if not self.int8_lm_head and extra_options.get("int4_algo_config", "default") not in {"rtn", "k_quant"}:
Can we rewrite the above section and the if condition to just match on the conditions needed for tied embeddings to be true and otherwise set it to false?
Something like this:
self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "rtn_last"}
self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False)
# matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
# tied_embeddings lm_head.MatMul.weight_Q{}G{} only works with rtn & k_quant on 4-bit.
self.int4_tied_embeddings = <boolean expression>
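Purely for illustration (the reviewer leaves the actual expression to the author): one possible reading of the constraints in the comments above, keeping tied 4-bit embeddings only when the lm_head is not promoted to int8 and the quant method is plain rtn or k_quant:

```python
# Hypothetical expression only, derived from the comments above; not part of the PR.
self.int4_tied_embeddings = (
    self.int4_tied_embeddings
    and not self.int8_lm_head
    and extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"}
)
```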
Can you update the options for onnxruntime-genai/src/python/py/models/builder.py, Lines 4685 to 4688 in d4eabac?
Problem
The current model builder doesn't support sharing the embedding layer with 4-bit qweights, so the weights occupy more room on disk and hurt the compression rate. builder.py also doesn't provide flexible options to toggle the graph construction and quantization config, such as unpacked/packed MatMul, RTN, k_quant, etc.
Solution
- Calculated `flat_dim` in a more generic way on the Reshape node before `GatherBlockQuantized` (supports 4-bit and 8-bit); see the sketch below. Added CUDA kernel support in ORT #26484.
- Added more `extra_options` to enable different quant configs and pack options.
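A rough sketch of the idea behind the generic `flat_dim` calculation; the function and the assumed packed-weight layout (N, n_blocks_per_col, blob_size) are illustrative assumptions, not the actual builder.py code:

```python
# Illustrative only: how the flattened row width of a packed quantized weight can be
# computed generically, so the same Reshape feeds GatherBlockQuantized for 4-bit and 8-bit.
def packed_flat_dim(n_blocks_per_col: int, block_size: int, bits: int) -> int:
    bytes_per_block = block_size * bits // 8   # 4-bit: two weights per byte; 8-bit: one
    return n_blocks_per_col * bytes_per_block  # bytes per vocabulary row after flattening

# Example: hidden_size=2048, block_size=32
print(packed_flat_dim(2048 // 32, 32, 4))  # 1024 bytes per row for 4-bit
print(packed_flat_dim(2048 // 32, 32, 8))  # 2048 bytes per row for 8-bit
```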
Running examples (a reproduction sketch follows the list):
- Unpacked qkv_projs and shared 4-bit RTN on Llama3.2 1B Instruct
- Shared 4-bit k_quant on Phi-4-Mini Instruct
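A hypothetical reproduction sketch for the first example; the `create_model` entry point, its parameter order, and the exact extra_options spellings (`unpack_matmul`, `int4_algo_config`) are assumptions here and should be checked against the final builder.py:

```python
# Hypothetical invocation sketch; adjust names, paths, and options to the actual builder.py API.
from onnxruntime_genai.models.builder import create_model

create_model(
    "meta-llama/Llama-3.2-1B-Instruct",  # model_name
    "",                                  # input_path (empty -> download from Hugging Face)
    "./llama3.2-1b-instruct-int4",       # output_dir
    "int4",                              # precision
    "cpu",                               # execution_provider
    "./cache",                           # cache_dir
    unpack_matmul="true",                # separate qkv_proj MatMuls (new option in this PR)
    int4_algo_config="rtn",              # 4-bit RTN with shared emb_tokens/lm_head
)
```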
Changes
Modified Files
- `src/python/py/models/builder.py`

Key Modifications
- Calculated `flat_dim` in a generic manner before feeding into `GatherBlockQuantized`.
- Added an `unpack_matmul` option to separate qkv_proj if needed.
- Added `rtn_last` (like `k_quant_last`) as a new mixed-precision option.
- Added `k_quant` (like `rtn`) as a new 4-bit quantizer option.
flat_dimin a generic manner before feeding inGatherBlockQuantized.unpack_matmuloption to separate qvk_proj if needed.rtn_lastlikek_quant_lastas a new mixed precision optionk_quantlikertnas a new 4 bit quantizer option