add UTs
Signed-off-by: xin3he <[email protected]>
xin3he committed Jun 19, 2024
commit 0ae3c7b119a6d3d4fbee4fdb4477e3eaa074388a
21 changes: 19 additions & 2 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -19,12 +19,20 @@
# limitations under the License.


import copy
from collections import OrderedDict

import torch

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
from neural_compressor.torch.utils import (
get_accelerator,
get_attr,
is_transformers_imported,
logger,
set_attr,
set_module,
)

from .utility import cast_fp8, quant_tensor, search_clip

@@ -64,6 +72,7 @@ def convert(
quantile=1.0,
use_full_range=False,
use_mse_search=False,
quant_lm_head=False,
*args,
**kwargs,
):
@@ -80,8 +89,10 @@
quantile (float, optional): percentile of clip. Defaults to 1.0.
use_full_range (bool, optional): Whether to use the full symmetric range, i.e. -2**(bits-1).
Defaults to False.
use_mse_search (bool, optional): Whether search clip range.
use_mse_search (bool, optional): Whether to search clip range.
Defaults to False.
quant_lm_head (bool, optional): Whether to quantize the lm_head layer.
Defaults to False.

Returns:
model: fake quantized torch module
@@ -93,6 +104,12 @@
# TODO: refine later; put modules on the device one by one instead of moving the whole model
model.to(device)

# For transformers models: if lm_head is tied to the embedding, deepcopy it so it can be quantized independently.
if quant_lm_head and getattr(getattr(model, "config", None), "tie_word_embeddings", False):
for key in model._tied_weights_keys:
weight = get_attr(model, key)
set_attr(model, key, copy.deepcopy(weight))

assert isinstance(model, torch.nn.Module), "only support torch module"
if is_transformers_imported():
supported_layers = (torch.nn.Linear, transformers.Conv1D)
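For context, the tied-weight handling added above can be pictured with a small stand-alone sketch (not part of this patch). It uses the tiny OPT checkpoint from the new test, whose config sets tie_word_embeddings, so lm_head.weight and the input embedding are the same Parameter object; quantizing lm_head in place would therefore also alter the embedding table, and deep-copying first gives lm_head its own tensor. The patch itself walks model._tied_weights_keys via get_attr/set_attr; the sketch assigns the copy directly for brevity.

import copy

import transformers

# tie_word_embeddings=True for this checkpoint, so lm_head shares its weight
# with the input embedding.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-random-OPTForCausalLM"
)
assert model.lm_head.weight is model.model.decoder.embed_tokens.weight

# Break the tie before quantization so RTN can rewrite lm_head without
# touching the embedding table.
if getattr(model.config, "tie_word_embeddings", False):
    model.lm_head.weight = copy.deepcopy(model.lm_head.weight)

assert model.lm_head.weight is not model.model.decoder.embed_tokens.weight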
2 changes: 1 addition & 1 deletion neural_compressor/torch/quantization/algorithm_entry.py
@@ -92,7 +92,7 @@ def rtn_entry(
}

quantizer = get_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config)
model = quantizer.execute(model, mode=mode)
model = quantizer.execute(model, mode=mode, quant_lm_head=quant_config.quant_lm_head)
model.qconfig = configs_mapping
model.save = MethodType(save, model)
postprocess_model(model, mode, quantizer)
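For reference, a minimal usage sketch of the option exercised by this change, going through the 3.x quantization API (the checkpoint is the tiny GPT-J model already used in the tests):

import transformers

from neural_compressor.torch.quantization import RTNConfig, convert, prepare

model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM"
)
# quant_lm_head defaults to False, which leaves lm_head unquantized; setting it
# to True makes rtn_entry forward the flag to RTNQuantizer.execute as above.
quant_config = RTNConfig(quant_lm_head=True)
model = prepare(model, quant_config)
model = convert(model)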
4 changes: 4 additions & 0 deletions neural_compressor/torch/utils/utility.py
@@ -102,6 +102,10 @@ def set_module(model, op_name, new_module):
setattr(second_last_module, name_list[-1], new_module)


get_attr = fetch_module
set_attr = set_module


def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]:
module_dict = dict(model.named_modules())
filter_result = []
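The two aliases simply give fetch_module/set_module attribute-flavored names, so rtn.py can read and replace a parameter by its dotted key. A rough usage sketch mirroring the rtn.py change above:

import copy

import transformers

from neural_compressor.torch.utils import get_attr, set_attr

model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM"
)
# Fetch a parameter by its dotted name, then swap in an independent deep copy.
weight = get_attr(model, "lm_head.weight")
set_attr(model, "lm_head.weight", copy.deepcopy(weight))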
28 changes: 28 additions & 0 deletions test/3x/torch/quantization/weight_only/test_rtn.py
@@ -138,6 +138,34 @@ def test_mse_search(self):
except:
assert torch.allclose(atol_false, atol_true, atol=0.012), "atol is very close, double checked the logic."

def test_quant_lm_head(self):
# tie_word_embeddings=false
gptj_model = transformers.AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM",
device_map=device,
)
lm_head_id = id(gptj_model.lm_head.weight)
assert id(gptj_model.transformer.wte.weight) != lm_head_id, "The lm_head weight is tied, please check!"
quant_config = RTNConfig(quant_lm_head=True)
model = prepare(gptj_model, quant_config)
model = convert(model)

# tie_word_embeddings=true
opt_model = transformers.AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-random-OPTForCausalLM",
device_map=device,
)
lm_head_id = id(opt_model.lm_head.weight)
assert (
id(opt_model.model.decoder.embed_tokens.weight) == lm_head_id
), "The lm_head weight is not tied, please check!"
quant_config = RTNConfig(quant_lm_head=True)
model = prepare(opt_model, quant_config)
model = convert(model)
assert (
id(model.model.decoder.embed_tokens.weight) == lm_head_id
), "The tied lm_head weight is not deep copied, please check!"

def test_layer_wise(self):
model = copy.deepcopy(self.tiny_gptj)
quant_config = RTNConfig(