support quant lm_head
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel committed Aug 14, 2024
commit 009f1be4d22959126dbad06662d9ba7684692bdf
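For reviewers: a minimal usage sketch of the new flag, mirroring the `test_quant_lm_head` case added below. The import path and the prepare/run/convert flow are taken from the 3.x test file; `model` and `run_fn` stand in for a user model and calibration function.

```python
from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

quant_config = GPTQConfig(quant_lm_head=True)  # new option introduced by this commit
model = prepare(model, quant_config)           # `model`: any supported causal LM (placeholder)
run_fn(model)                                  # `run_fn`: user-supplied calibration loop (placeholder)
model = convert(model)                         # lm_head is packed into INCWeightOnlyLinear
```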
173 changes: 173 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -186,6 +186,7 @@ def __init__(
device=None,
use_layer_wise=False,
model_path="",
quant_lm_head=False,
dataloader=None,
*args,
**kwargs,
@@ -233,6 +234,7 @@ def __init__(
self.act_order_default = False
self.static_groups_default = False
self.true_sequential_default = False
self.quant_lm_head = quant_lm_head
self.perchannel_default = True
self.mse_default = False
self.use_double_quant_default = False
@@ -755,6 +757,175 @@ def tmp(_, inp, out):
# iteratively replace the input with output, thus layerwise quantization can continue.
self.update_blockwise_hidden_states(outs)
logger.info("------------------------------")
# 2.7.1 do the post transformer blocks quantization
do_post_transformer_quant = self.quant_lm_head
if do_post_transformer_quant:
logger.info("Quantizing post transformer layers")
# the input should be self.cache_key_arguments and self.cache_positional_arguments
sub_layers = find_layers(self.gptq_related_blocks["transformers_post"]["layer"])
sub_layers_to_quant = {}
for layer_name, layer_obj in sub_layers.items():
# filter sub_layers with included layer_names in self.weight_config
full_layer_name = self.gptq_related_blocks["transformers_post"]["name"]
# if self.weight_config.get(full_layer_name, None) == None:
if self.get_layer_config(full_layer_name) is None:
logger.warning(f"{full_layer_name} can be quantized " + "but excluded from quantization configs.")
else:
sub_layers_to_quant[full_layer_name] = layer_obj
del sub_layers
sub_layers = sub_layers_to_quant
gptq_post_block = {}

def add_batch_post(_name):
def tmp(_, inp, out):
gptq_post_block[_name].add_batch(inp[0].data, out.data)

return tmp

for layer_name in sub_layers:
full_layer_name = self.gptq_related_blocks["transformers_post"]["name"]
weight_config_this_layer = self.get_layer_config(full_layer_name)
if self.use_layer_wise: # pragma: no cover
from neural_compressor.torch.algorithms.layer_wise import load_value

full_layer_name = self.gptq_related_blocks["transformers_post"]["name"]
W = load_value(self.model, full_layer_name + ".weight", self.model_path)
else:
W = sub_layers[layer_name].weight.data.clone()

gptq_post_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device)
# gptq_for_this_block[layer_name].quantizer = Quantizer()
gptq_post_block[layer_name].quantizer.configure(
weight_config_this_layer
)
# generate the gptq quantizer
handles = [] # register handles which add inputs and outputs to gptq object
for layer_name in sub_layers:
handles.append(sub_layers[layer_name].register_forward_hook(add_batch_post(layer_name)))
for j in range(len(self.dataloader)):
if "hidden_states" in self.cache_key_arguments:
out = sub_layers[layer_name](self.cache_key_arguments["hidden_states"][j])
else:
out = sub_layers[layer_name](self.cache_positional_arguments[0][j])

# if "hidden_states" in self.cache_key_arguments:
# self.cache_key_arguments["hidden_states"] = outs[:]
# else:
# self.cache_positional_arguments[0] = outs[:]
# perform the inference process

for h in handles:
h.remove()

for layer_name in sub_layers:
full_layer_name = self.gptq_related_blocks["transformers_post"]["name"]
weight_config_this_layer = self.get_layer_config(full_layer_name)
scale, zp, Q = gptq_post_block[layer_name].fasterquant(
W,
blocksize=weight_config_this_layer["block_size"],
percdamp=weight_config_this_layer["percdamp"],
groupsize=weight_config_this_layer["group_size"],
act_order=weight_config_this_layer["act_order"],
static_groups=weight_config_this_layer["static_groups"],
)
if self.use_layer_wise: # pragma: no cover
from neural_compressor.torch.algorithms.layer_wise import (
LWQ_WORKSPACE,
clean_module_weight,
load_value,
set_module_tensor_to_device,
)

sub_layer = sub_layers[layer_name]
full_layer_name = self.gptq_related_blocks["transformers_post"]["name"]
for n, p in sub_layer.named_parameters():
param_name = full_layer_name + "." + n
if n == "weight":
set_module_tensor_to_device(self.model, param_name, self.device, Q)
else:
value = load_value(self.model, param_name, self.model_path)
set_module_tensor_to_device(self.model, param_name, self.device, value)
# sub_layer.weight.data = Q
torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
clean_module_weight(sub_layer)
del Q
gc.collect()
else:
sub_layers[layer_name].weight.data = Q
# save the quantization results
gptq_config[full_layer_name] = {"scale": scale}
if not weight_config_this_layer["sym"]:
gptq_config[full_layer_name]["zero"] = zp
if weight_config_this_layer["act_order"] and not weight_config_this_layer["static_groups"]:
# save perm for restoring the weights, but only when static_groups is not enabled.
gptq_config[full_layer_name]["perm"] = gptq_post_block[full_layer_name].perm
gptq_post_block[layer_name].free()

# 2.7.2 lm_head: export to compressed model
for layer_name in sub_layers:
full_layer_name = self.gptq_related_blocks["transformers_post"]["name"]
weight_config_this_layer = self.get_layer_config(full_layer_name)
gptq_scale = gptq_config[full_layer_name]["scale"]
if not weight_config_this_layer["sym"]:
gptq_zp = gptq_config[full_layer_name]["zero"]
else:
gptq_zp = None
if weight_config_this_layer["act_order"]: # save perm for restoring the weights
gptq_perm = gptq_config[full_layer_name]["perm"]
else:
gptq_perm = None
if self.use_layer_wise: # pragma: no cover
state_dict = torch.load(
LWQ_WORKSPACE + f"/{full_layer_name}.pt"
)
Q = state_dict["weight"].data
bias = state_dict["bias"] if "bias" in state_dict.keys() else None
else:
Q = sub_layers[layer_name].weight.data
if weight_config_this_layer["act_order"]:
Q.copy_(Q[:, gptq_perm])
if is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D): # pragma: no cover
Q = Q.t_().contiguous()
from .utility import quant_weight_w_scale

quant_weight_w_scale(
Q,
gptq_scale,
gptq_zp,
weight_config_this_layer["group_size"],
dtype=weight_config_this_layer["dtype"],
)
if weight_config_this_layer["act_order"]:
invperm = torch.argsort(gptq_perm)
Q.copy_(Q[:, invperm])
int_weight = Q.type(torch.int32) # copy_ is not workable for different types.
# replace module
if isinstance(sub_layers[layer_name], torch.nn.Linear):
in_features = sub_layers[layer_name].in_features
out_features = sub_layers[layer_name].out_features
elif is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D): # pragma: no cover
in_features = sub_layers[layer_name].weight.shape[0]
out_features = sub_layers[layer_name].weight.shape[1]
int_weight = sub_layers[layer_name].weight.t_().contiguous()
scale = scale.t_().contiguous()
zp = zp.t_().contiguous() if zp is not None else zp

if not self.use_layer_wise: # pragma: no cover
bias = sub_layers[layer_name].bias

new_module = INCWeightOnlyLinear(
in_features,
out_features,
dtype=weight_config_this_layer["dtype"],
bits=weight_config_this_layer["bits"],
group_size=weight_config_this_layer["group_size"],
zp=gptq_zp is not None,
bias=bias is not None,
g_idx=gptq_perm is not None,
device=self.device,
)
new_module.pack(int_weight, gptq_scale, gptq_zp, bias, gptq_perm)
set_module(self.model, layer_name, new_module)

logger.info("Quantization done")
# self.model.config.use_cache = self.use_cache
@@ -1169,6 +1340,7 @@ def prepare(
device=None,
use_layer_wise=False,
model_path=None,
quant_lm_head=False,
*args,
**kwargs,
):
@@ -1188,6 +1360,7 @@ def prepare(
device=device,
use_layer_wise=use_layer_wise,
model_path=model_path,
quant_lm_head=quant_lm_head,
)
self.gptq_quantizer.prepare_for_calibration()
return self.gptq_quantizer.model
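The core of the gptq.py change is step 2.7.1: the lm_head (the `transformers_post` layer) is wrapped in a `GPTQ` object, forward hooks feed it the cached hidden states, and the quantized weight is finally packed into `INCWeightOnlyLinear`. Below is a simplified, self-contained sketch of just the hook/statistics part; `HessianCollector` is a stand-in for what `GPTQ.add_batch` accumulates (a running estimate of 2·E[x xᵀ], following the usual GPTQ formulation), not this file's implementation verbatim.

```python
import torch

class HessianCollector:
    """Stand-in for the per-layer statistics a GPTQ object accumulates via add_batch."""

    def __init__(self, columns: int):
        self.H = torch.zeros(columns, columns)
        self.nsamples = 0

    def add_batch(self, inp: torch.Tensor):
        x = inp.reshape(-1, inp.shape[-1]).float()        # [tokens, in_features]
        n = x.shape[0]
        self.H *= self.nsamples / (self.nsamples + n)     # keep a running average
        self.nsamples += n
        x = x.t() * (2.0 / self.nsamples) ** 0.5
        self.H += x @ x.t()                               # running estimate of 2 * E[x x^T]

lm_head = torch.nn.Linear(16, 32, bias=False)             # toy "post transformer" layer
collector = HessianCollector(columns=16)

def add_batch_hook(_module, inputs, _output):             # mirrors add_batch_post/tmp above
    collector.add_batch(inputs[0].detach())

handle = lm_head.register_forward_hook(add_batch_hook)
for hidden_states in (torch.randn(2, 8, 16) for _ in range(4)):  # cached hidden states
    lm_head(hidden_states)
handle.remove()                                            # same cleanup as `h.remove()` above
print(collector.H.shape)                                   # torch.Size([16, 16])
```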
1 change: 1 addition & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -165,6 +165,7 @@ def gptq_entry(
{
"use_layer_wise": quant_config.use_layer_wise,
"model_path": quant_config.model_path,
"quant_lm_head": quant_config.quant_lm_head,
}
)
kwargs.pop("example_inputs")
1 change: 0 additions & 1 deletion neural_compressor/torch/quantization/config.py
@@ -412,7 +412,6 @@ def __init__(
white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types.
Default is DEFAULT_WHITE_LIST.
"""
assert not quant_lm_head, "GPTQ doesn't support lm_head quantization currently, it's coming soon!"
super().__init__(white_list=white_list)
self.dtype = dtype
self.bits = bits
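Removing this assertion is what actually exposes the feature: `GPTQConfig(quant_lm_head=True)` now constructs instead of raising, and the flag is forwarded to the quantizer through the `gptq_entry` change above. A quick check, assuming the same import path as the tests:

```python
from neural_compressor.torch.quantization import GPTQConfig

cfg = GPTQConfig(quant_lm_head=True)  # previously failed with the assertion removed above
assert cfg.quant_lm_head is True
```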
43 changes: 39 additions & 4 deletions test/3x/torch/quantization/weight_only/test_gptq.py
@@ -182,9 +182,9 @@ def test_act_order(self):
# compare atol, this case is an ideal case.
assert atol_false > atol_true, "act_order=True doesn't help accuracy, maybe is reasonable, please double check."

def test_layer_wise(self):
def test_layer_wise(self, quant_lm_head=False):
model = copy.deepcopy(self.tiny_gptj)
quant_config = GPTQConfig()
quant_config = GPTQConfig(quant_lm_head=quant_lm_head)
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
@@ -194,12 +194,18 @@ def test_layer_wise(self):

model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM")

quant_config = GPTQConfig(use_layer_wise=True, model_path="hf-internal-testing/tiny-random-GPTJForCausalLM")
quant_config = GPTQConfig(
use_layer_wise=True,
quant_lm_head=quant_lm_head,
model_path="hf-internal-testing/tiny-random-GPTJForCausalLM")
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
out = model(self.example_inputs)[0]
assert torch.equal(out, q_label), "use_layer_wise=True output should be same. Please double check."
assert (torch.equal(out, q_label)
), f"use_layer_wise=True and quant_lm_head={quant_lm_head} output should be same. Please double check."
if not quant_lm_head:
self.test_layer_wise(quant_lm_head=True) # Avoid errors raised by @pytest.mark.parametrize

def test_true_sequential(self):
# true_sequential=False
@@ -226,6 +232,35 @@ def test_true_sequential(self):
assert (
atol_false < atol_true
), "true_sequential=True doesn't help accuracy, maybe is reasonable, please double check."

def test_quant_lm_head(self):
# quant_lm_head=False
model = copy.deepcopy(self.tiny_gptj)
quant_config = GPTQConfig(
quant_lm_head=False,
)
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
out = model(self.example_inputs)[0]
atol_false = (out - self.label).amax()
# quant_lm_head=True
model = copy.deepcopy(self.tiny_gptj)
quant_config = GPTQConfig(
quant_lm_head=True,
)
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
out = model(self.example_inputs)[0]
atol_true = (out - self.label).amax()
# compare atol, this case is an ideal case.
assert (
atol_false < atol_true
), "quant_lm_head=True doesn't help accuracy, maybe is reasonable, please double check."
assert (
get_woq_linear_num(model, "INCWeightOnlyLinear") == 31
), "Incorrect number of INCWeightOnlyLinear modules"

@pytest.mark.parametrize("dtype", ["nf4", "int4"])
@pytest.mark.parametrize("double_quant_bits", [6])
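For reference, `get_woq_linear_num` used in `test_quant_lm_head` comes from the existing test helpers. A rough equivalent for anyone verifying the expected count of 31 (only the total is asserted by the test; this helper is an assumption for illustration, not the project's implementation):

```python
import torch

def count_modules_by_name(model: torch.nn.Module, class_name: str) -> int:
    """Count submodules whose class name matches, e.g. 'INCWeightOnlyLinear'."""
    return sum(1 for m in model.modules() if type(m).__name__ == class_name)

# With quant_lm_head=True on tiny-random-GPTJForCausalLM, the test expects 31:
# count_modules_by_name(model, "INCWeightOnlyLinear")
```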