Support xpu for ipex static quant
Signed-off-by: gta <[email protected]>
gta authored and gta committed Jul 12, 2024
commit 3cf0e089896352b13d721e470c7c5d180dd92288
110 changes: 70 additions & 40 deletions neural_compressor/torch/algorithms/static_quant/static_quant.py
@@ -33,11 +33,13 @@

from neural_compressor.torch.algorithms import Quantizer
from neural_compressor.torch.utils import logger
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .utility import (
CpuInfo,
cfg_to_qconfig,
dump_model_op_stats,
generate_xpu_qconfig,
get_ipex_version,
get_quantizable_ops_recursively,
ipex_config_path,
@@ -68,45 +70,64 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
Returns:
A prepared model.
"""
device = auto_detect_accelerator().current_device()
assert example_inputs is not None, "Please provide example_inputs for static quantization."

_, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
model, example_inputs
)
# update json file in ipex_config_path; map ipex op_name to pt op_name
self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
model.eval()
if device == "cpu":
_, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively(
model, example_inputs
)
# update json file in ipex_config_path; map ipex op_name to pt op_name
self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
else:
model = model.to("xpu")

use_bf16 = self.quant_config.get("use_bf16", None)
model.eval()

# Check save_qconf_summary part is a workaround for IPEX bug.
# Sometimes the prepared model from get_op_capability loses this attribute
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"):
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

if ipex_ver.release >= Version("2.1").release:
# HistogramObserver will cause a performance issue.
# static_qconfig = ipex.quantization.default_static_qconfig_mapping
qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
from torch.ao.quantization import QConfigMapping

static_qconfig = QConfigMapping().set_global(qconfig)
else:
static_qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)
if isinstance(example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
)
# Sometimes the prepared model from get_op_capability loses these attributes
if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover
from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig

if device != "cpu": # pragma: no cover
from torch.quantization.quantize_jit import prepare_jit

with torch.no_grad():
modelJit = torch.jit.trace(model, example_inputs)
qconfig = generate_xpu_qconfig(self.quant_config)
model = prepare_jit(modelJit, qconfig, inplace)
else:
model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace)
if ipex_ver.release >= Version("2.1").release:
# HistogramObserver will cause a performance issue.
# static_qconfig = ipex.quantization.default_static_qconfig_mapping
qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric
),
)
from torch.ao.quantization import QConfigMapping

static_qconfig = QConfigMapping().set_global(qconfig)
else:
static_qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric
),
)
if isinstance(example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace
)
else:
model = ipex.quantization.prepare(
model, static_qconfig, example_inputs=example_inputs, inplace=inplace
)

if device == "cpu":
model.load_qconf_summary(qconf_summary=ipex_config_path)

model.load_qconf_summary(qconf_summary=ipex_config_path)
return model

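For orientation, a minimal sketch of how this prepare()/convert() pair is normally driven through the 3.x front-end API; the toy model, inputs, and calibration loop below are illustrative assumptions, not part of this commit. On a machine where auto_detect_accelerator().current_device() reports "xpu", the branch above traces the model and routes it through prepare_jit() with the qconfig from generate_xpu_qconfig(); on CPU it keeps the existing ipex.quantization.prepare() flow.

import torch
from neural_compressor.torch.quantization import StaticQuantConfig, convert, prepare

# Toy model and inputs for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
example_inputs = torch.randn(4, 16)

quant_config = StaticQuantConfig()  # activation algorithm/scheme come from this config
prepared = prepare(model, quant_config, example_inputs=example_inputs)

# Calibration pass: feeds the IPEX observers on CPU, or the prepare_jit observers on XPU.
with torch.no_grad():
    prepared(example_inputs)

q_model = convert(prepared)
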
def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
@@ -120,22 +141,31 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
Returns:
A quantized model.
"""
device = auto_detect_accelerator().current_device()
use_bf16 = self.quant_config.get("use_bf16", None)

from neural_compressor.torch.algorithms.static_quant import save

model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)
if device != "cpu": # pragma: no cover
from torch.quantization.quantize_jit import convert_jit

model = convert_jit(model, inplace)
simple_inference(model, example_inputs, iterations=2)
dump_model_op_stats(self.quant_config["op"])
else:
model.save_qconf_summary(qconf_summary=ipex_config_path)
model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace)

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
model.ipex_config_path = ipex_config_path

with open(ipex_config_path, "r") as f:
model.tune_cfg = json.load(f)
model.ipex_config_path = ipex_config_path
dump_model_op_stats(self.user_cfg)

dump_model_op_stats(self.user_cfg)
model.ori_save = model.save
model.save = MethodType(save, model)

logger.info("Static quantization done.")
model.ori_save = model.save
model.save = MethodType(save, model)
return model


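Because the XPU path swaps IPEX's eager prepare/convert for TorchScript graph-mode quantization, the sequence it follows is trace -> prepare_jit -> calibrate -> convert_jit, then a couple of warm-up runs (simple_inference). Below is a CPU-runnable sketch of that sequence using a toy module and a hand-written tune_cfg dict; both are assumptions, not from this commit, and on a real XPU device the module and inputs would first be moved to "xpu".

import torch
from torch.quantization.quantize_jit import convert_jit, prepare_jit

from neural_compressor.torch.algorithms.static_quant.utility import generate_xpu_qconfig

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = torch.randn(2, 8)

# Hand-written stand-in for self.quant_config["op"]; the real dict is built by the framework.
tune_cfg = {
    "op": {
        ("0", "Linear"): {
            "activation": {"algorithm": "minmax", "scheme": "asym"},
            "weight": {"dtype": "int8"},
        }
    }
}

with torch.no_grad():
    traced = torch.jit.trace(model, example_inputs)                       # prepare(): XPU branch
    prepared = prepare_jit(traced, generate_xpu_qconfig(tune_cfg), False)
    prepared(example_inputs)                                              # calibration pass
    quantized = convert_jit(prepared, False)                              # convert(): XPU branch
    quantized(example_inputs)                                             # warm-up, cf. simple_inference(..., iterations=2)
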
86 changes: 85 additions & 1 deletion neural_compressor/torch/algorithms/static_quant/utility.py
@@ -24,11 +24,12 @@

try:
import intel_extension_for_pytorch as ipex
import prettytable as pt
except: # pragma: no cover
pass

from neural_compressor.common.utils import DEFAULT_WORKSPACE, CpuInfo
from neural_compressor.torch.utils import Statistics, get_ipex_version, get_torch_version, logger
from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger

version = get_torch_version()
ipex_ver = get_ipex_version()
@@ -163,6 +164,47 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
return cfgs, ori_user_cfg


def generate_xpu_qconfig(tune_cfg):
# qconfig observer & config constants for ipex-xpu
from torch.ao.quantization import HistogramObserver, MinMaxObserver, QConfig

act_observer_minmax_asym = MinMaxObserver.with_args(quant_min=0, quant_max=127)
act_observer_minmax_sym = MinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
)
act_observer_kl_asym = HistogramObserver.with_args(quant_min=0, quant_max=127)
act_observer_kl_sym = HistogramObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127
)
# no tuning for granularity due to tuning space
weight_observer_minmax_sym = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)

qconfig = {}
user_cfg = copy.deepcopy(tune_cfg["op"])
for _, cfg in user_cfg.items():
act_algo = cfg["activation"]["algorithm"]
act_sym = cfg["activation"]["scheme"]
break

if act_algo == "minmax":
if act_sym == "sym":
activation = act_observer_minmax_sym
else:
activation = act_observer_minmax_asym
else:
if act_sym == "sym":
activation = act_observer_kl_sym
else:
activation = act_observer_kl_asym

qconfig[""] = QConfig(activation=activation, weight=weight_observer_minmax_sym)

for (op_name, op_type), cfg in user_cfg.items():
if cfg["weight"]["dtype"] == "fp32":
qconfig[op_name] = None
return qconfig

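To make the helper's behaviour concrete, here is a hedged example of the mapping it returns for a small, hand-written tune_cfg; the op names and settings below are invented for illustration.

tune_cfg = {
    "op": {
        ("conv1", "Conv2d"): {
            "activation": {"algorithm": "kl", "scheme": "sym"},
            "weight": {"dtype": "int8"},
        },
        ("fc", "Linear"): {
            "activation": {"algorithm": "kl", "scheme": "sym"},
            "weight": {"dtype": "fp32"},
        },
    }
}

qconfig = generate_xpu_qconfig(tune_cfg)
# qconfig[""]   -> global QConfig: symmetric HistogramObserver activation,
#                  per-tensor symmetric MinMaxObserver weight
# qconfig["fc"] -> None, so the fp32-weight op is left unquantized by prepare_jit()
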

def generate_activation_observer(
scheme, algorithm, smooth_quant=False, smooth_quant_enable=False, alpha=0.5
): # pragma: no cover
@@ -566,6 +608,48 @@ def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids
return quantizable_ops


class Statistics: # pragma: no cover
"""The statistics printer."""

def __init__(self, data, header, field_names, output_handle=logger.info):
"""Init a Statistics object.

Args:
data: The statistics data
header: The table header
field_names: The field names
output_handle: The output logging method
"""
self.field_names = field_names
self.header = header
self.data = data
self.output_handle = output_handle
self.tb = pt.PrettyTable(min_table_width=40)

def print_stat(self):
"""Print the statistics."""
valid_field_names = []
for index, value in enumerate(self.field_names):
if index < 2:
valid_field_names.append(value)
continue

if any(i[index] for i in self.data):
valid_field_names.append(value)
self.tb.field_names = valid_field_names
for i in self.data:
tmp_data = []
for index, value in enumerate(i):
if self.field_names[index] in valid_field_names:
tmp_data.append(value)
if any(tmp_data[1:]):
self.tb.add_row(tmp_data)
lines = self.tb.get_string().split("\n")
self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|")
for i in lines:
self.output_handle(i)

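A short, illustrative use of the vendored Statistics helper (the data values below are made up); this mirrors how dump_model_op_stats is expected to render its per-op-type summary.

field_names = ["Op Type", "Total", "INT8", "BF16", "FP32"]
data = [["Linear", 10, 8, 0, 2], ["Conv2d", 4, 4, 0, 0]]
Statistics(data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
# Emits a PrettyTable summary via logger.info; the first two columns are always kept,
# while any later column whose values are all falsy (here BF16) is dropped.
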

class TransformerBasedModelBlockPatternDetector: # pragma: no cover
"""Detect the attention block and FFN block in transformer-based model."""
