update hqq
Signed-off-by: xin3he <[email protected]>
xin3he committed Jun 20, 2024
commit b208ae3b21a67fc127174bb0a6a8384a50d9adc9
31 changes: 25 additions & 6 deletions docs/3x/PT_WeightOnlyQuant.md
@@ -1,6 +1,7 @@

PyTorch Weight Only Quantization
===============

- [Introduction](#introduction)
- [Supported Matrix](#supported-matrix)
- [Usage](#usage)
@@ -28,7 +29,6 @@ Besides, as mentioned in many papers[1][2], activation quantization is the main

Theoretically, round-to-nearest (RTN) is the most straightforward way to quantize weights using scale maps. However, when the number of bits is small (e.g. 3), the MSE loss is larger than expected. A group size is introduced to reduce the number of elements that share the same scale, which improves accuracy.
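
To make the role of the group size concrete, below is a minimal, self-contained sketch of symmetric round-to-nearest fake quantization with per-group scales. It is illustrative only (not the Intel(R) Neural Compressor implementation), and the function name and tensor shapes are hypothetical.

```python
# Minimal RTN sketch (illustrative only, not the library implementation).
import torch


def rtn_fake_quantize(weight: torch.Tensor, bits: int = 4, group_size: int = 32) -> torch.Tensor:
    """Symmetric round-to-nearest fake quantization with one scale per `group_size` input elements."""
    c_out, c_in = weight.shape
    assert c_in % group_size == 0, "for simplicity, C_in must be divisible by group_size"
    qmax = 2 ** (bits - 1) - 1  # e.g. 7 for 4-bit symmetric quantization
    # Group elements along the input channel so that each group shares one scale.
    w = weight.reshape(c_out, c_in // group_size, group_size)
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
    return (q * scale).reshape(c_out, c_in)  # dequantized ("fake quantized") weight


w = torch.randn(64, 128)
print((w - rtn_fake_quantize(w, bits=4, group_size=128)).abs().mean())  # one scale per row
print((w - rtn_fake_quantize(w, bits=4, group_size=32)).abs().mean())  # smaller groups, lower error
```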


## Supported Matrix

| Algorithms/Backend | PyTorch eager mode |
@@ -58,25 +58,28 @@ Theoretically, round-to-nearest (RTN) is the most straightforward way to quantiz
WeightOnlyQuant quantization for PyTorch uses the prepare and convert [APIs](./PyTorch.md#quantization-apis).

#### Common arguments

| Config | Capability |
|---|---|
| dtype (str)| ['int', 'nf4', 'fp4'] |
| bits (int)| [1, ..., 8] |
| group_size (int)| [-1, 1, ..., $C_{in}$] |
| use_sym (bool)| [True, False] |
| quant_lm_head (bool)| [False, True] |
| use_double_quant (bool) | [True, False] |
| double_quant_dtype (str) | ['int'] |
| double_quant_bits (int) | [1, ..., bits] |
| double_quant_use_sym (bool) | [True, False] |
| double_quant_group_size (int) | [-1, 1, ..., $C_{in}$] |

Notes:

- *group_size = -1* refers to **per output channel quantization**. Taking a linear layer (input channel = $C_{in}$, output channel = $C_{out}$) as an example, when *group_size = -1*, quantization computes $C_{out}$ quantization parameters in total. Otherwise, when *group_size = gs*, quantization parameters are computed for every $gs$ elements along the input channel, leading to $C_{out} \times (C_{in} / gs)$ quantization parameters in total (see the short example after these notes).
- 4-bit NormalFloat (NF4) is proposed in QLoRA[7]. 'fp4' includes [fp4_e2m1](../../neural_compressor/adaptor/torch_utils/weight_only.py#L37) and [fp4_e2m1_bnb](https://github.com/TimDettmers/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L735). By default, fp4 refers to fp4_e2m1_bnb.
- Only RTN and GPTQ support double quant.
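
As a quick check of the parameter-count rule above, consider a hypothetical 4096x4096 linear layer (the sizes are examples only):

```python
# Hypothetical 4096x4096 linear layer: count the groups of quantization parameters produced.
c_in, c_out = 4096, 4096

per_output_channel = c_out  # group_size = -1: one set of parameters per output channel
group_size = 128
per_group = c_out * (c_in // group_size)  # group_size = 128

print(per_output_channel)  # 4096
print(per_group)  # 131072
```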


#### RTN

| rtn_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| group_dim (int) | Dimension for grouping | 1 |
@@ -86,6 +89,7 @@ Notes:
| model_path (str) | Model path that is used to load state_dict per layer | |

> **Note:** `model_path` is only used when `use_layer_wise=True`. Layer-wise quantization support is still in progress; stay tuned.

``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, RTNConfig
@@ -96,6 +100,7 @@ model = convert(model)
```
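
Because only RTN and GPTQ support double quantization, the double-quant arguments from the common table above can be combined with `RTNConfig` as sketched below; the specific values are illustrative.

```python
# Double quantization with RTN (values are illustrative).
from neural_compressor.torch.quantization import prepare, convert, RTNConfig

quant_config = RTNConfig(
    dtype="int",
    bits=4,
    group_size=128,
    use_double_quant=True,  # also quantize the scales
    double_quant_dtype="int",
    double_quant_bits=4,
    double_quant_use_sym=True,
    double_quant_group_size=256,
)
model = prepare(model, quant_config)
model = convert(model)
```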

#### GPTQ

| gptq_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| use_mse_search (bool) | Enables mean squared error (MSE) search | False
@@ -107,6 +112,7 @@ model = convert(model)
| block_size (int) | Execute GPTQ quantization per block, block shape = [C_out, block_size] | 128 |
| static_groups (bool) | Whether to calculate group-wise quantization parameters in advance. This option mitigates act_order's extra computational requirements. | False |
> **Note:** `model_path` is only used when `use_layer_wise=True`. Layer-wise quantization support is still in progress; stay tuned.

``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, GPTQConfig
@@ -118,6 +124,7 @@ model = convert(model)
```

#### AutoRound

| autoround_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| enable_full_range (bool) | Whether to enable full range quantization | False
@@ -138,6 +145,7 @@ model = convert(model)
| not_use_best_mse (bool) | Whether to use mean squared error | False |
| dynamic_max_gap (int) | The dynamic maximum gap | -1 |
| scale_dtype (str) | The data type of quantization scale to be used, different kernels have different choices | "float16" |

``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, AutoRoundConfig
@@ -149,6 +157,7 @@ model = convert(model)
```

#### AWQ

| awq_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| group_dim (int) | Dimension for grouping | 1 |
@@ -159,6 +168,7 @@ model = convert(model)
| use_auto_clip (bool) | Enables clip range search | True |
| folding (bool) | Allow inserting a mul op before a linear layer when the scale cannot be absorbed by the previous layer | False |
> **Note:** Layer-wise quantization support is still in progress; stay tuned.

``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, AWQConfig
@@ -170,6 +180,7 @@ model = convert(model)
```

#### TEQ

| teq_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| group_dim (int) | Dimension for grouping | 1 |
@@ -179,6 +190,7 @@ model = convert(model)
| use_double_quant (bool) | Enables double quantization | False |
| folding (bool) | Allow inserting a mul op before a linear layer when the scale cannot be absorbed by the previous layer | False |
> **Note:** Layer-wise quantization support is still in progress; stay tuned.

``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, TEQConfig
@@ -190,12 +202,13 @@ model = convert(model)
```

#### HQQ

| hqq_args | comments | default value |
|----------|-------------|-------------------------------------------------------------------|
| quant_zero (bool) | Whether to quantize zero point | True |
| quant_scale (bool) | Whether to quantize the scale | False |
| scale_quant_group_size (int) | The group size for quantizing scale | 128 |
| quant_lm_head (bool) | Whether to quantize the lm_head layer | False |

``` python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, HQQConfig
@@ -205,10 +218,13 @@ model = prepare(model, quant_config)
run_fn(model) # calibration
model = convert(model)
```
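
By default HQQ keeps `lm_head` in fp32. Quantizing it can be opted into through the common `quant_lm_head` argument, as in this sketch:

```python
# Quantize lm_head together with the other linear layers (it is skipped by default).
from neural_compressor.torch.quantization import prepare, convert, HQQConfig

quant_config = HQQConfig(quant_lm_head=True)
model = prepare(model, quant_config)
model = convert(model)
```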

### Specify Quantization Rules

Intel(R) Neural Compressor supports specifying quantization rules by operator name or operator type. Users can set `local` in a dict or use the `set_local` method of a config class for this purpose.

1. Example of setting `local` from a dict

```python
quant_config = {
"rtn": {
@@ -226,15 +242,19 @@ quant_config = {
}
}
```

2. Example of using `set_local`

```python
quant_config = RTNConfig()
lm_head_config = RTNConfig(dtype="fp32")
quant_config.set_local("lm_head", lm_head_config)
```

### Saving and Loading

The saved_results folder contains two files: quantized_model.pt and qconfig.json; the generated model is a quantized model that contains WeightOnlyLinear modules. To support low-memory inference, Intel(R) Neural Compressor implements WeightOnlyLinear, a torch.nn.Module, to compress the fake-quantized fp32 model. Since torch does not provide flexible low-bit data type storage, WeightOnlyLinear packs the low-bit data (weights and zero points) into a larger data type, such as torch.int8 or torch.int32. When WeightOnlyLinear is used for inference, it restores the compressed data to float32 and runs the torch linear function.

```python
# Quantization code
from neural_compressor.torch.quantization import prepare, convert, RTNConfig
@@ -255,7 +275,6 @@ loaded_model = load(
) # Note: the original_model argument should be the original (unquantized) model.
```
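
The packing idea behind `WeightOnlyLinear` can be illustrated with a minimal sketch; the real storage layout may differ, and the helpers below only show how two 4-bit values fit into one byte:

```python
# Illustrative 4-bit packing; WeightOnlyLinear's actual layout may differ.
import torch


def pack_int4_pairs(q: torch.Tensor) -> torch.Tensor:
    """Pack int4 values (given as integers in [0, 15]) two per byte into a uint8 tensor."""
    q = q.to(torch.uint8).reshape(-1, 2)
    return q[:, 0] | (q[:, 1] << 4)  # low nibble = first value, high nibble = second


def unpack_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
    low = packed & 0xF
    high = (packed >> 4) & 0xF
    return torch.stack([low, high], dim=1).reshape(-1)


q = torch.randint(0, 16, (16,))
packed = pack_int4_pairs(q)  # 16 int4 values stored in 8 bytes
assert torch.equal(unpack_int4_pairs(packed), q.to(torch.uint8))
```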


## Examples

Users can also refer to [examples](https://github.com/intel/neural-compressor/blob/master/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only) on how to quantize a model with WeightOnlyQuant.
@@ -272,6 +291,6 @@

[5]. Cheng, Wenhua, et al. "Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs" arXiv preprint arXiv:2309.05516 (2023).

[6]. Badri, Hicham and Shaji, Appu. "Half-Quadratic Quantization of Large Machine Learning Models." [Online] Available: https://mobiusml.github.io/hqq_blog/ (2023).
[6]. Badri, Hicham and Shaji, Appu. "Half-Quadratic Quantization of Large Machine Learning Models." [Online] Available: <https://mobiusml.github.io/hqq_blog/> (2023).

[7]. Dettmers, Tim, et al. "Qlora: Efficient finetuning of quantized llms." arXiv preprint arXiv:2305.14314 (2023).
13 changes: 0 additions & 13 deletions neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py
@@ -85,7 +85,6 @@ def __init__(self, quant_config: ConfigMappingType) -> None:
Args:
quant_config (ConfigMappingType): quantization config for ops.
"""
quant_config = self._parse_hqq_configs_mapping(quant_config)
super().__init__(quant_config=quant_config)

@torch.no_grad()
@@ -142,15 +141,3 @@ def _convert_hqq_module_config(self, config) -> HQQModuleConfig:
hqq_module_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig)
logger.debug(hqq_module_config)
return hqq_module_config

def _parse_hqq_configs_mapping(self, configs_mapping):
qconfig_mapping = {}
for (op_name, op_type), quant_config in configs_mapping.items():
if quant_config.skip_lm_head and "lm_head" in op_name:
logger.warning("Skip quantizing %s due to `skip_lm_head` is True.", op_name)
continue
if quant_config is not None and quant_config.dtype == "fp32":
logger.warning("Fallback %s.", op_name)
continue
qconfig_mapping[op_name] = self._convert_hqq_module_config(quant_config)
return qconfig_mapping
35 changes: 23 additions & 12 deletions neural_compressor/torch/quantization/config.py
@@ -591,7 +591,7 @@ def __init__(
double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int'
double_quant_use_sym: bool = True,
double_quant_group_size: int = 256,
# double quant
# quant lm_head
quant_lm_head: bool = False,
# teq
absorb_to_layer: dict = {},
@@ -1231,7 +1231,8 @@ class HQQConfig(BaseConfig):
"quant_zero",
"quant_scale",
"scale_quant_group_size",
"skip_lm_head",
# quant_lm_head
"quant_lm_head",
]
supported_configs: List[OperatorConfig] = []

@@ -1243,7 +1244,8 @@ def __init__(
quant_zero: bool = True,
quant_scale: bool = False,
scale_quant_group_size: int = 128,
skip_lm_head: bool = True,
# quant lm_head
quant_lm_head: bool = False,
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
super().__init__(white_list=white_list)
@@ -1253,9 +1255,18 @@ def __init__(
self.quant_zero = quant_zero
self.quant_scale = quant_scale
self.scale_quant_group_size = scale_quant_group_size
self.skip_lm_head = skip_lm_head
self.quant_lm_head = quant_lm_head
self._post_init()

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
# TODO: to be refined
supported_configs = []
linear_hqq_config = HQQConfig()
operators = list(WOQ_WHITE_LIST)
supported_configs.append(OperatorConfig(config=linear_hqq_config, operators=operators))
cls.supported_configs = supported_configs

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
filter_result = []
@@ -1265,14 +1276,14 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
filter_result.append(pair)
return filter_result

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
# TODO: to be refined
supported_configs = []
linear_hqq_config = HQQConfig()
operators = [torch.nn.Linear]
supported_configs.append(OperatorConfig(config=linear_hqq_config, operators=operators))
cls.supported_configs = supported_configs
def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
if not self.quant_lm_head:
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
self.set_local(usual_lm_head_names, HQQConfig(dtype="fp32"))
config_mapping = super().to_config_mapping(config_list, model_info)
return config_mapping

@classmethod
def get_config_set_for_tuning(cls) -> Union[None, "HQQConfig", List["HQQConfig"]]:
46 changes: 34 additions & 12 deletions test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py
@@ -3,10 +3,15 @@

import pytest
import torch
import transformers
from transformers import AutoModelForCausalLM

from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option
from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
from neural_compressor.torch.quantization import HQQConfig, convert, get_default_hqq_config, prepare, quantize
from neural_compressor.torch.utils import accelerator

device = accelerator.current_device_name()


def _common_cpu_test(nbits=4, group_size=64, quant_zero=True, quant_scale=False, scale_quant_group_size=128):
@@ -65,10 +70,9 @@ def force_not_half(self, monkeypatch):
monkeypatch.setattr(hqq_global_option, "use_half", False)

def test_hqq_quant(self, force_use_cpu, force_not_half):
from neural_compressor.torch.quantization import convert, get_default_hqq_config, prepare, quantize

hqq_global_option.use_half = False
fp32_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
fp32_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-OPTForCausalLM")
example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long, device="cpu")
# test_default_config
quant_config = get_default_hqq_config()
@@ -88,7 +92,6 @@ def test_hqq_quant(self, force_use_cpu, force_not_half):
), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."

def test_hqq_fallback(self, force_use_cpu, force_not_half):
from neural_compressor.torch.quantization import HQQConfig, convert, prepare

class ToyModel(torch.nn.Module):
def __init__(self):
@@ -106,6 +109,34 @@ def forward(self, x):
assert type(qmodel.fc1).__name__ == torch.nn.Linear.__name__, f"Expect fallback fc1, but get {type(qmodel.fc1)}"
assert type(qmodel.fc2).__name__ != torch.nn.Linear.__name__, f"Expect quantize fc2, but get {type(qmodel.fc2)}"

def test_quant_lm_head(self, force_use_cpu, force_not_half):
# tie_word_embeddings=false
gptj_model = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
device_map=device,
)
lm_head_id = id(gptj_model.lm_head.weight)
assert id(gptj_model.transformer.wte.weight) != lm_head_id, "The lm_head weight is tied, please check!"
quant_config = HQQConfig(quant_lm_head=True)
model = prepare(gptj_model, quant_config)
model = convert(model)

# tie_word_embeddings=true
opt_model = transformers.AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-random-OPTForCausalLM",
device_map=device,
)
lm_head_id = id(opt_model.lm_head.weight)
assert (
id(opt_model.model.decoder.embed_tokens.weight) == lm_head_id
), "The lm_head weight is not tied, please check!"
quant_config = HQQConfig(quant_lm_head=True)
model = prepare(opt_model, quant_config)
model = convert(model)
assert (
id(model.model.decoder.embed_tokens.weight) == lm_head_id
), "The tied lm_head weight is not deep copied, please check!"

@pytest.mark.parametrize(
"nbits, group_size, quant_zero, quant_scale, scale_quant_group_size",
[
@@ -134,12 +165,3 @@ def test_hqq_module_cpu(
quant_scale=quant_scale,
scale_quant_group_size=scale_quant_group_size,
)


# _common_cpu_test(
# nbits=4,
# group_size=64,
# quant_zero=False,
# quant_scale=False,
# scale_quant_group_size=128
# )