update UT
Signed-off-by: xin3he <[email protected]>
xin3he committed Jun 20, 2024
commit 39f649ce8dc139d60c933e357a23deb795b2be1c
54 changes: 50 additions & 4 deletions neural_compressor/torch/quantization/config.py
@@ -127,7 +127,7 @@ def __init__(
double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int'
double_quant_use_sym: bool = False,
double_quant_group_size: int = 256,
# double quant
# quant lm_head
quant_lm_head: bool = False,
# Tuning space
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
@@ -272,6 +272,8 @@ class GPTQConfig(BaseConfig):
# layer wise params
"use_layer_wise",
"model_path",
# quant lm_head
"quant_lm_head",
# gptq params
"act_order",
"percdamp",
@@ -295,6 +297,8 @@ def __init__(
double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int'
double_quant_use_sym: bool = False,
double_quant_group_size: int = 256,
# quant lm_head
quant_lm_head: bool = False,
# gptq params
act_order: bool = False,
percdamp: float = 0.01,
@@ -318,6 +322,7 @@ def __init__(
double_quant_bits (int): Number of bits used to represent double_quant scale. Default is 8.
double_quant_use_sym (bool): Indicates whether the double_quant scale is symmetric. Default is False.
double_quant_group_size (int): Size of double_quant groups. Default is 256.
quant_lm_head (bool): Indicates whether to quantize the lm_head layer in transformers. Default is False.
act_order (bool): Whether to sort Hessian's diagonal values to rearrange channel-wise
quantization order. Default is False.
percdamp (float): Percentage of Hessian's diagonal values' average, which will be added to
@@ -328,6 +333,7 @@ def __init__(
This option mitigates act_order's extra computational requirements.
Default is False.
"""
assert not quant_lm_head, "GPTQ doesn't support lm_head quantization currently, it's coming soon!"
super().__init__(white_list=white_list)
self.dtype = dtype
self.bits = bits
@@ -348,17 +354,27 @@ def __init__(
self.percdamp = percdamp
self.block_size = block_size
self.static_groups = static_groups
self.quant_lm_head = quant_lm_head
self._post_init() # initialize global & local configuration

@classmethod
def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
# TODO(Yi)
linear_gptq_config = GPTQConfig()
operators = [torch.nn.Linear]
operators = list(WOQ_WHITE_LIST)
supported_configs.append(OperatorConfig(config=linear_gptq_config, operators=operators))
cls.supported_configs = supported_configs

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
if not self.quant_lm_head:
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
self.set_local(usual_lm_head_names, GPTQConfig(dtype="fp32"))
config_mapping = super().to_config_mapping(config_list, model_info)
return config_mapping
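To make the flag's intended behavior concrete, here is a minimal user-side sketch; only the import path and constructor arguments visible in this file are assumed:

```python
from neural_compressor.torch.quantization import GPTQConfig

# quant_lm_head defaults to False, so to_config_mapping() pins ".*lm_head",
# ".*output_layer" and ".*embed_out" to a local fp32 GPTQConfig and those
# layers are left unquantized.
quant_config = GPTQConfig(bits=4)

# Opting in is rejected until GPTQ gains lm_head support; the assertion added
# in __init__ above fires immediately:
# GPTQConfig(bits=4, quant_lm_head=True)  # AssertionError
```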

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
filter_result = []
@@ -408,6 +424,8 @@ class AWQConfig(BaseConfig):
"double_quant_bits",
"double_quant_use_sym",
"double_quant_group_size",
# quant_lm_head
"quant_lm_head",
# AWQ params
"use_auto_scale",
"use_auto_clip",
@@ -431,6 +449,8 @@ def __init__(
double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int'
double_quant_use_sym: bool = True,
double_quant_group_size: int = 256,
# quant lm_head
quant_lm_head: bool = False,
# awq
use_auto_scale: bool = True,
use_auto_clip: bool = True,
@@ -453,6 +473,7 @@ def __init__(
double_quant_bits (int): Number of bits used to represent double_quant scale, default is 8.
double_quant_use_sym (bool): Indicates whether the double_quant scale is symmetric, default is True.
double_quant_group_size (int): Size of double_quant groups, default is 256.
quant_lm_head (bool): Indicates whether to quantize the lm_head layer in transformers. Default is False.
use_auto_scale (bool): Enables best scales search based on activation distribution, default is True.
use_auto_clip (bool): Enables clip range search. Defaults to True.
folding (bool): Allow inserting a mul op before the linear layer when the scale cannot be absorbed by the last layer,
@@ -473,6 +494,7 @@ def __init__(
self.double_quant_dtype = double_quant_dtype
self.double_quant_use_sym = double_quant_use_sym
self.double_quant_group_size = double_quant_group_size
self.quant_lm_head = quant_lm_head
self.use_auto_scale = use_auto_scale
self.use_auto_clip = use_auto_clip
self.folding = folding
@@ -483,10 +505,19 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
# TODO(Yi)
linear_awq_config = AWQConfig()
operators = [torch.nn.Linear, torch.nn.functional.linear]
operators = list(WOQ_WHITE_LIST)
supported_configs.append(OperatorConfig(config=linear_awq_config, operators=operators))
cls.supported_configs = supported_configs

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
if not self.quant_lm_head:
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
self.set_local(usual_lm_head_names, AWQConfig(dtype="fp32"))
config_mapping = super().to_config_mapping(config_list, model_info)
return config_mapping
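Unlike GPTQ, AWQ does not assert against the flag, so lm_head quantization can actually be switched on. A hedged end-to-end sketch, assuming the 3.x `prepare`/`convert` entry points; the toy model and calibration pass below are illustrative only and not part of this PR:

```python
import torch

from neural_compressor.torch.quantization import AWQConfig, convert, prepare


class ToyLM(torch.nn.Module):
    """Tiny stand-in model, only here to make the sketch self-contained."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)
        self.lm_head = torch.nn.Linear(8, 16)

    def forward(self, x):
        return self.lm_head(self.fc(x))


model = ToyLM()
example_inputs = torch.randn(2, 8)

quant_config = AWQConfig(quant_lm_head=True)  # also quantize lm_head/output_layer/embed_out
model = prepare(model, quant_config, example_inputs=example_inputs)
model(example_inputs)  # minimal calibration pass
model = convert(model)
```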

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
filter_result = []
@@ -536,6 +567,8 @@ class TEQConfig(BaseConfig):
"double_quant_bits",
"double_quant_use_sym",
"double_quant_group_size",
# quant_lm_head
"quant_lm_head",
# TEQ params
"absorb_to_layer",
"folding",
@@ -558,6 +591,8 @@ def __init__(
double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int'
double_quant_use_sym: bool = True,
double_quant_group_size: int = 256,
# quant lm_head
quant_lm_head: bool = False,
# teq
absorb_to_layer: dict = {},
folding: bool = True,
@@ -579,6 +614,7 @@ def __init__(
double_quant_bits (int): Number of bits used to represent double_quant scale, default is 8.
double_quant_use_sym (bool): Indicates whether the double_quant scale is symmetric, default is True.
double_quant_group_size (int): Size of double_quant groups, default is 256.
quant_lm_head (bool): Indicates whether to quantize the lm_head layer in transformers. Default is False.
absorb_to_layer (dict): The layer dict into which the scale can be absorbed, default is {}.
folding (bool): Allow inserting a mul op before the linear layer when the scale cannot be absorbed by the last layer,
default is False.
@@ -598,6 +634,7 @@ def __init__(
self.double_quant_dtype = double_quant_dtype
self.double_quant_use_sym = double_quant_use_sym
self.double_quant_group_size = double_quant_group_size
self.quant_lm_head = quant_lm_head
self.absorb_to_layer = absorb_to_layer
self.folding = folding
self._post_init()
@@ -607,10 +644,19 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
supported_configs = []
# TODO(Yi)
linear_teq_config = TEQConfig()
operators = [torch.nn.Linear, torch.nn.functional.linear]
operators = list(WOQ_WHITE_LIST)
supported_configs.append(OperatorConfig(config=linear_teq_config, operators=operators))
cls.supported_configs = supported_configs

def to_config_mapping(
self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
if not self.quant_lm_head:
usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"]
self.set_local(usual_lm_head_names, TEQConfig(dtype="fp32"))
config_mapping = super().to_config_mapping(config_list, model_info)
return config_mapping
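The exclusion is ordinary `set_local` usage, so the same effect can be reproduced (or overridden) by hand; this sketch only restates what the branch above does when `quant_lm_head` is False:

```python
from neural_compressor.torch.quantization import TEQConfig

quant_config = TEQConfig()  # quant_lm_head defaults to False
# Same effect as the branch in to_config_mapping(): the usual lm_head names
# get a local fp32 config and are skipped during quantization.
quant_config.set_local([".*lm_head", ".*output_layer", ".*embed_out"], TEQConfig(dtype="fp32"))
```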

@staticmethod
def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
filter_result = []
@@ -1,21 +1,65 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

from neural_compressor.torch.algorithms.weight_only.hqq.bitpack import Packer
from neural_compressor.torch.algorithms.weight_only.hqq.config import (
HQQModuleConfig,
QTensorConfig,
default_hqq_module_config,
default_scale_quant_config,
default_weight_quant_config,
default_zero_quant_config,
)
from neural_compressor.torch.algorithms.weight_only.hqq.qtensor import QTensor, QTensorMetaInfo


def test_default_hqq_module_config():
config = default_hqq_module_config
print(config)
assert isinstance(config, HQQModuleConfig)
assert config.weight == default_weight_quant_config
assert config.zero == default_zero_quant_config
assert config.scale == default_scale_quant_config


def test_default_weight_quant_config():
config = default_weight_quant_config
assert isinstance(config, QTensorConfig)
assert config.nbits == 4
assert config.channel_wise is True


def test_default_zero_quant_config():
config = default_zero_quant_config
assert isinstance(config, QTensorConfig)
assert config.nbits == 8
assert config.channel_wise is False


def test_default_scale_quant_config():
config = default_scale_quant_config
assert isinstance(config, QTensorConfig)
assert config.nbits == 8
assert config.channel_wise is True


def test_qtensor_meta_info():
meta_info = QTensorMetaInfo
print(meta_info)


@pytest.mark.parametrize("nbits", [2, 3, 4, 8])
def test_packer(nbits):
# TODO: add test for 3 bits
range_max = 2**nbits
dims = 16 if nbits != 3 else 10
W = torch.randint(0, range_max, (dims, dims)).to(torch.uint8)
W_pack = Packer.get_pack_fn(nbits)(W)
W_pack_unpack = Packer.get_unpack_fn(nbits)(W_pack)
assert torch.allclose(W, W_pack_unpack)
print("Packer test passed!")


class TestQTensor:
def test_q_tensor(self):
in_feats = 3
44 changes: 0 additions & 44 deletions test/3x/torch/quantization/weight_only/hqq/test_hqq_config.py

This file was deleted.

16 changes: 0 additions & 16 deletions test/3x/torch/quantization/weight_only/hqq/test_packer.py

This file was deleted.
