Merged

Commits (42)
4cf0620  support rtn & gptq(draft) (Kaihui-intel, Jun 25, 2024)
a1d9e10  clean code (Kaihui-intel, Jun 25, 2024)
b4e93f3  clean gptq (Kaihui-intel, Jun 25, 2024)
a3a061e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 25, 2024)
02ee1f8  del unused line (Kaihui-intel, Jun 25, 2024)
060ea50  fix load import (Kaihui-intel, Jun 26, 2024)
1a60731  fix rtn model_path (Kaihui-intel, Jun 26, 2024)
04e1923  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 26, 2024)
8f27d47  update rtn model (Kaihui-intel, Jun 26, 2024)
263c581  Merge branch 'kaihui/lw' of https://github.com/intel/neural-compresso… (Kaihui-intel, Jun 26, 2024)
5a3f090  fix clean module (Kaihui-intel, Jun 26, 2024)
14bd733  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 26, 2024)
4ce74db  fix layerwise woq forward (Kaihui-intel, Jun 26, 2024)
199fe4c  Merge branch 'kaihui/lw' of https://github.com/intel/neural-compresso… (Kaihui-intel, Jun 26, 2024)
b700d39  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 26, 2024)
96d0e05  fix import (Kaihui-intel, Jun 26, 2024)
4337eac  Merge branch 'kaihui/lw' of https://github.com/intel/neural-compresso… (Kaihui-intel, Jun 26, 2024)
7b2d326  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 26, 2024)
77cde5c  update clean module & add timestep (Jul 3, 2024)
6cf8ff3  add numba pack (Jul 11, 2024)
0e388c0  mimor fix numba (Jul 11, 2024)
b0ccd62  apply mask (Jul 11, 2024)
0f7de68  support gptq (Jul 11, 2024)
83c6a9b  keep q_model in memory (Jul 12, 2024)
483c219  merge master (Jul 12, 2024)
c543783  fix master conflict (Jul 12, 2024)
159aa34  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2024)
809c0fb  update numba requirements_pt (Jul 12, 2024)
308c7fc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2024)
5d80e9b  fix awq config (Jul 12, 2024)
c4af344  remove pack_with_reshpe (Jul 12, 2024)
e99ee19  recover ar (Jul 12, 2024)
1dd01a0  revert eg (Jul 12, 2024)
8dbf793  install py 3x deps (chensuyue, Jul 12, 2024)
0ea77fd  enhance import&add pack ut (Kaihui-intel, Jul 16, 2024)
eec87ac  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 16, 2024)
36a4a29  add pack ut file (Kaihui-intel, Jul 16, 2024)
86008f4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 16, 2024)
93a86f2  move load_empty_model to torch.utils (Kaihui-intel, Jul 16, 2024)
19b1c4d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 16, 2024)
f17c640  remove torch import (Kaihui-intel, Jul 16, 2024)
fa39f6f  fix ut import (Kaihui-intel, Jul 16, 2024)
Changes from 1 commit:
support rtn & gptq(draft)
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel committed Jun 25, 2024
commit 4cf0620c6e4e4ccc5a26d5a1b72afd6d9d73156d
2 changes: 2 additions & 0 deletions neural_compressor/torch/algorithms/layer_wise/utils.py
@@ -214,6 +214,8 @@ def _get_path(pretrained_model_name_or_path):
path = dowload_hf_model(pretrained_model_name_or_path)
return path

get_path = _get_path


def load_value(model, param_name, path):
if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True):
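get_path is added as a public alias of _get_path, which resolves a model name or local path to a checkpoint directory (downloading a Hugging Face snapshot via dowload_hf_model when needed) so the layer-wise RTN/GPTQ code below can import it. A minimal usage sketch, assuming only the helper signatures visible in this diff; the model id is the small test model used in this PR's unit tests:

# Sketch: resolve a checkpoint directory, then lazily read a single weight tensor from it.
from neural_compressor.torch.algorithms.layer_wise import load_empty_model
from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_value

model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM")   # architecture only, weights not materialized
model_path = get_path("hf-internal-testing/tiny-random-GPTJForCausalLM")      # local snapshot directory
weight = load_value(model, "transformer.h.0.attn.k_proj.weight", model_path)  # loads just this tensor from disk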
130 changes: 80 additions & 50 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -224,11 +224,13 @@ def __init__(

# device
self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
self.model.to(self.device)
if not use_layer_wise:
self.model.to(self.device)
self.is_ready = False

self.use_layer_wise = use_layer_wise
self.model_path = model_path
if use_layer_wise:
self.prepare_layer_wise(model_path)

# dataloader
self.use_max_length = use_max_length
@@ -237,6 +239,18 @@
self.dataloader = []
self.nsamples = nsamples

def prepare_layer_wise(self, model_path):
from neural_compressor.torch.algorithms.layer_wise import LWQ_WORKSPACE, get_path, register_weight_hooks
import os
os.makedirs(LWQ_WORKSPACE, exist_ok=True)
if model_path == "":
model_path = self.model.path
assert model_path, "model_path should not be None."
self.model_path = get_path(model_path)
register_weight_hooks(
self.model, self.model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE
)

def get_full_layer_name(self, sub_layer_name, block_idx):
transformer_name = self.gptq_related_blocks["transformers_name"]
return ".".join([transformer_name, str(block_idx), sub_layer_name])
@@ -394,7 +408,6 @@ def execute_quantization(self, means=None, stds=None):
# Step1: prepare quantization (calibration datasets)

logger.info("Begin ====>")
model_path = self.model_path

# Step2: run gptq quantization in a transformer block-wise manner.
gptq_config = {}
@@ -430,8 +443,8 @@
weight_config_this_layer = self.get_layer_config(full_layer_name)
if self.use_layer_wise: # pragma: no cover
from neural_compressor.torch.algorithms.layer_wise import load_value

W = load_value(self.model, full_layer_name + ".weight", model_path)
W = load_value(self.model, full_layer_name + ".weight", self.model_path)
else:
W = sub_layers[layer_name].weight.data.clone()

@@ -467,12 +480,23 @@ def tmp(_, inp, out):
weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
logger.info(f"Quantizing layer {layer_name}")
if self.use_layer_wise: # pragma: no cover
from neural_compressor.torch.algorithms.layer_wise import load_value
from neural_compressor.torch.algorithms.layer_wise import load_value, set_module_tensor_to_device

full_layer_name = self.get_full_layer_name(layer_name, block_idx)
W = load_value(self.model, full_layer_name + ".weight", model_path)
for n, p in sub_layers[layer_name].named_parameters():
param_name = full_layer_name + "." + n
if n == "weight":
W = load_value(self.model, full_layer_name + ".weight", self.model_path)
else:
value = load_value(self.model, param_name, self.model_path)
set_module_tensor_to_device(self.model, param_name, self.device, value)

else:
W = sub_layers[layer_name].weight.data.clone()

accelerator.mark_step()
if "hpu" in self.device:
W = W.to("cpu")
@@ -484,55 +508,16 @@ def tmp(_, inp, out):
act_order=weight_config_this_layer["act_order"],
static_groups=weight_config_this_layer["static_groups"],
)
if self.use_layer_wise: # pragma: no cover
from neural_compressor.torch.algorithms.layer_wise import (
LWQ_WORKSPACE,
clean_module_weight,
load_value,
set_module_tensor_to_device,
)

sub_layer = sub_layers[layer_name]
full_layer_name = self.get_full_layer_name(layer_name, block_idx)
for n, p in sub_layer.named_parameters():
param_name = full_layer_name + "." + n
if n == "weight":
set_module_tensor_to_device(self.model, param_name, self.device, Q)
else:
value = load_value(self.model, param_name, model_path)
set_module_tensor_to_device(self.model, param_name, self.device, value)
# sub_layer.weight.data = Q
torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
clean_module_weight(sub_layer)
del Q
gc.collect()
else:
sub_layers[layer_name].weight.data = Q

# Step 2.5: export to compressed model
gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale}
if not weight_config_this_layer["sym"]:
gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp
if weight_config_this_layer["act_order"]: # save perm for restoring the weights
gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"] = gptq_for_this_block[
layer_name
].perm
gptq_for_this_block[layer_name].free()

# Step 2.5: replace output data with quantized weights
outs = []
batch_num = self.cache_key_arguments.pop("batch_num")
for j in range(batch_num):
cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
out = self.track_hidden_states(out)
outs.append(out)
self.cache_key_arguments["batch_num"] = batch_num
if self.use_layer_wise: # pragma: no cover
self.gptq_related_blocks["transformers"][block_idx] = transformer_block
else:
self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
# Step 2.6: export to compressed model
for layer_name in sub_layers:

weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
gptq_scale = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["scale"]
if not weight_config_this_layer["sym"]:
@@ -543,7 +528,6 @@ def tmp(_, inp, out):
gptq_perm = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"]
else:
gptq_perm = None
Q = sub_layers[layer_name].weight.data
if weight_config_this_layer["act_order"]:
Q.copy_(Q[:, gptq_perm])
if is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D):
@@ -584,7 +568,52 @@ def tmp(_, inp, out):
device=self.device,
)
new_module.pack(int_weight, gptq_scale, gptq_zp, sub_layers[layer_name].bias, gptq_perm)


if self.use_layer_wise: # pragma: no cover
from neural_compressor.torch.algorithms.layer_wise import LWQ_WORKSPACE, clean_module_weight

full_layer_name = self.get_full_layer_name(layer_name, block_idx)
torch.save(new_module.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
clean_module_weight(new_module)
del Q
gc.collect()
set_module(transformer_block, layer_name, new_module)

gptq_for_this_block[layer_name].free()

# Step 2.6: replace output data with quantized weights
outs = []
batch_num = self.cache_key_arguments.pop("batch_num")
for j in range(batch_num):
cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
out = self.track_hidden_states(out)
outs.append(out)
self.cache_key_arguments["batch_num"] = batch_num
if self.use_layer_wise: # pragma: no cover
self.gptq_related_blocks["transformers"][block_idx] = transformer_block
else:
self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()


del gptq_for_this_block
torch.cuda.empty_cache()
# iteratively replace the input with output, thus layerwise quantization can continue.
@@ -999,6 +1028,7 @@ def prepare(
def convert(self, model, *args, **kwargs):
self.gptq_quantizer.model = model
self.gptq_quantizer.remove_prepare_for_calibration()

q_model, gptq_config = self.gptq_quantizer.execute_quantization()
q_model.gptq_config = gptq_config
logger.info("GPTQ quantizing done.")
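In layer-wise mode the quantized layers are no longer kept as dense tensors in memory: each packed module is written to LWQ_WORKSPACE and then stripped by clean_module_weight before the loop moves on. A self-contained toy sketch of that save-then-clean pattern (the workspace path, layer name, and plain Linear module are placeholders, not the library code):

import gc
import os
import tempfile

import torch

workspace = os.path.join(tempfile.gettempdir(), "lwq_workspace")  # stand-in for LWQ_WORKSPACE
os.makedirs(workspace, exist_ok=True)

full_layer_name = "transformer.h.0.mlp.fc_in"  # example layer name
new_module = torch.nn.Linear(16, 16)           # stand-in for the packed weight-only module

# 1) persist the packed layer to the workspace
torch.save(new_module.state_dict(), os.path.join(workspace, f"{full_layer_name}.pt"))

# 2) release the in-memory tensors (rough analogue of clean_module_weight)
for p in new_module.parameters():
    p.data = torch.empty(0)
gc.collect()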
3 changes: 2 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/modules.py
@@ -175,7 +175,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
self.scales = self.scales.T.contiguous()
self.qweight = self.qweight.T.contiguous()
self.qzeros = self.qzeros.T.contiguous()
int_weight = int_weight.to(self.device)
if int_weight.device.type != "meta":
int_weight = int_weight.to(self.device)
if self.use_optimum_format and zp is None:
# to avoid overflow
int_weight = int_weight.type(torch.int32)
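The new guard in pack() skips the device transfer when int_weight lives on the meta device, which is what the layer-wise path hands in after weights have been cleaned. A tiny standalone illustration of why the check is needed (not repository code):

import torch

w = torch.empty(4, 4, device="meta")  # placeholder tensor with no backing storage
print(w.device.type)                  # "meta"

# A meta tensor has no data to copy, so .to() would fail; pack() therefore only
# transfers real tensors:
if w.device.type != "meta":
    w = w.to("cpu")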
40 changes: 36 additions & 4 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -22,6 +22,7 @@
from collections import OrderedDict

import torch
import gc

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
@@ -89,10 +90,6 @@ def convert(
weight_config = self.quant_config
device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()

# Put model on device explicitly
# TODO: refine it later, Put module on device one by one instead of the whole model
model.to(device)

assert isinstance(model, torch.nn.Module), "only support torch module"
if is_transformers_imported():
supported_layers = (torch.nn.Linear, transformers.Conv1D)
@@ -130,6 +127,7 @@
use_full_range = weight_config[name]["use_full_range"]
use_mse_search = weight_config[name]["use_mse_search"]
use_layer_wise = weight_config[name]["use_layer_wise"]
model_path = weight_config[name]["model_path"]
use_optimum_format = kwargs.get("use_optimum_format", True)
# double quant config
double_quant_config = {
@@ -154,6 +152,24 @@
continue
logger.debug(f"RTN quantized module:{name, m}")
logger.debug(log_msg)

if use_layer_wise:
from neural_compressor.common.utils import DEFAULT_WORKSPACE
from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module
import os

lwq_workspace = os.path.join(DEFAULT_WORKSPACE, "lwq_tmpdir")
os.makedirs(lwq_workspace, exist_ok=True)
model_path = get_path(model_path)

# load only this module's weights from disk
load_module(model, name, model_path, device=device)
else:
# Put model on device explicitly
# TODO: refine it later, Put module on device one by one instead of the whole model
model.to(device)

# for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight.
if is_transformers_imported():
transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
Expand Down Expand Up @@ -202,8 +218,24 @@ def convert(
device=device,
)
new_module.pack(int_weight, scale, zp, m.bias)

if use_layer_wise:
# save and clean weight
from neural_compressor.torch.algorithms.layer_wise.utils import clean_module_weight

torch.save(new_module.state_dict(), os.path.join(lwq_workspace, f"{name}.pt"))
clean_module_weight(new_module)
del m
gc.collect()
if name == "":
return new_module
else:
set_module(model, name, new_module)

if use_layer_wise:
# register hooks
from neural_compressor.torch.algorithms.layer_wise.utils import register_weight_hooks

register_weight_hooks(model, model_path, device=device, clean_weight=True)
return model
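Taken together, the RTN changes let a model that was never fully loaded be quantized layer by layer: load_module pulls one module's weights from disk, the module is quantized and packed, its state_dict is saved to the layer-wise workspace, and clean_module_weight releases it again before register_weight_hooks is attached for later forward passes. An end-to-end usage sketch mirroring this PR's updated unit test (treat it as illustrative; the prepare/convert API and the tiny test model are the ones used elsewhere in this PR):

from neural_compressor.torch.algorithms.layer_wise import load_empty_model
from neural_compressor.torch.quantization import RTNConfig, convert, prepare

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = load_empty_model(model_name)  # architecture only; weights stay on disk
quant_config = RTNConfig(
    use_layer_wise=True,
    model_path=model_name,  # where load_module()/load_value() read weights from
)
model = prepare(model, quant_config)
model = convert(model)  # per layer: load -> quantize -> pack -> save -> clean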
1 change: 1 addition & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -84,6 +84,7 @@ def rtn_entry(
"use_full_range": quant_config.use_full_range,
"use_mse_search": quant_config.use_mse_search,
"use_layer_wise": quant_config.use_layer_wise,
"model_path": quant_config.model_path,
"use_double_quant": quant_config.use_double_quant,
"double_quant_dtype": quant_config.double_quant_dtype,
"double_quant_bits": quant_config.double_quant_bits,
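With this entry-point change, model_path travels with the other RTN options into the per-layer configuration mapping handed to the RTN quantizer. A hypothetical excerpt of that mapping for a single layer, limited to the keys visible in this hunk (the layer name, model id, and values are illustrative defaults, not taken from a real run):

weight_config = {
    "transformer.h.0.mlp.fc_in": {
        "use_full_range": False,
        "use_mse_search": False,
        "use_layer_wise": True,
        "model_path": "hf-internal-testing/tiny-random-GPTJForCausalLM",
        "use_double_quant": False,
        "double_quant_dtype": "int",
        "double_quant_bits": 8,
    },
}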
36 changes: 26 additions & 10 deletions test/3x/torch/quantization/weight_only/test_gptq.py
@@ -28,8 +28,10 @@ def run_fn(model):
# GPTQ uses ValueError to reduce computation when collecting input data of the first block
# It's special for UTs, no need to add this wrapper in examples.
with pytest.raises(ValueError):
model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device))
model(torch.tensor([[40, 50, 60]], dtype=torch.long).to(device))
model(torch.tensor([[10, 20, 30]], dtype=torch.long))
model(torch.tensor([[40, 50, 60]], dtype=torch.long))


class TestGPTQQuant:
@@ -170,14 +172,28 @@ def test_act_order(self):
# compare atol, this case is an ideal case.
assert atol_false > atol_true, "act_order=True doesn't help accuracy, maybe is reasonable, please double check."

# def test_layer_wise(self):
# model = copy.deepcopy(self.tiny_gptj)
# quant_config = GPTQConfig(
# use_layer_wise=True,
# )
# model = quantize(model, quant_config, run_fn=run_fn)
# TODO: (Xin) not implemented

def test_layer_wise(self):
model = copy.deepcopy(self.tiny_gptj)
quant_config = GPTQConfig()
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
q_label = model(self.example_inputs)[0]

from neural_compressor.torch.algorithms.layer_wise import load_empty_model
model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM", torchscript=True)

quant_config = GPTQConfig(
use_layer_wise=True,
model_path="hf-internal-testing/tiny-random-GPTJForCausalLM"
)
model = quantize(model, quant_config, run_fn=run_fn)
out = model(self.example_inputs)[0]
atol_true = (out - q_label).amax()
print(out, atol_true)

@pytest.mark.parametrize("dtype", ["nf4", "int4"])
@pytest.mark.parametrize("double_quant_bits", [6])
@pytest.mark.parametrize("double_quant_group_size", [8, 256])
8 changes: 6 additions & 2 deletions test/3x/torch/quantization/weight_only/test_rtn.py
@@ -139,13 +139,17 @@ def test_mse_search(self):
assert torch.allclose(atol_false, atol_true, atol=0.012), "atol is very close, double checked the logic."

def test_layer_wise(self):
model = copy.deepcopy(self.tiny_gptj)
from neural_compressor.torch.algorithms.layer_wise import load_empty_model
model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM")
quant_config = RTNConfig(
use_layer_wise=True,
model_path="hf-internal-testing/tiny-random-GPTJForCausalLM",
)
model = prepare(model, quant_config)
model = convert(model)
# TODO: (Xin) not implemented
out = model(self.example_inputs)[0]
assert torch.equal(out, self.q_label), "use_layer_wise=True output should be the same. Please double check."

@pytest.mark.parametrize(
"dtype",