Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix 3x ipex static quant regression
Signed-off-by: Cheng, Zixuan <[email protected]>
  • Loading branch information
violetch24 committed Jun 13, 2024
commit b81277ce72bb0f10a692d5e1f78a0606bb544ecc
55 changes: 54 additions & 1 deletion neural_compressor/torch/algorithms/smooth_quant/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

from neural_compressor.torch.algorithms.static_quant import (
CpuInfo,
Statistics,
TransformerBasedModelBlockPatternDetector,
dump_model_op_stats,
generate_activation_observer,
get_quantizable_ops_from_cfgs,
ipex_config_path,
Expand Down Expand Up @@ -251,6 +251,59 @@ def cfg_to_qconfig(
return None


def dump_model_op_stats(user_cfg):
"""This is a function to dump quantizable ops of model to user.

Args:
user_cfg (dict): quantization config
Returns:
None
"""
res = dict()
for k, v in user_cfg.items():
op_type_list = k[-1].split("><")
op_type = ""
for op in op_type_list:
if "class" in op:
op_type = (
op[op.rfind(".") + 1 : op.rfind("'")]
if op_type == ""
else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")]
)
elif "method" in op:
start = op.find("'") + 1
if start > 1:
op_type = (
op[start : op.find("'", start)]
if op_type == ""
else op_type + "&" + op[start : op.find("'", start)]
)
else:
start = op.find("method") + 7
op_type = (
op[start : op.find(" ", start)]
if op_type == ""
else op_type + "&" + op[start : op.find(" ", start)]
)
else:
op_type = op if op_type == "" else op_type + "&" + op
if op_type not in res.keys():
res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0}
if v["weight"]["dtype"] == "int8":
res[op_type]["INT8"] += 1
elif v["weight"]["dtype"] == "fp32":
res[op_type]["FP32"] += 1

output_data = [
[op_type, sum(res[op_type].values()), res[op_type]["INT8"], res[op_type]["BF16"], res[op_type]["FP32"]]
for op_type in res.keys()
]

Statistics(
output_data, header="Mixed Precision Statistics", field_names=["Op Type", "Total", "INT8", "BF16", "FP32"]
).print_stat()


def get_parent(node, all_parents=False): # pragma: no cover
if node.inputs() is None:
return None
Expand Down
47 changes: 14 additions & 33 deletions neural_compressor/torch/algorithms/static_quant/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,14 @@
"<method 'add' of 'torch._C.TensorBase' objects>": "add", # for IPEX >= 2.2
"<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>": "AdaptiveAvgPool2d",
"Linear_Relu": "Linear",
"Linear_add": "Linear",
"<class 'torch.nn.modules.linear.Linear'>": "Linear",
"<class 'torch.nn.modules.pooling.MaxPool2d'>": "MaxPool2d",
"re": {"<built-in method matmul of type object at": "matmul"},
"re": {
"<built-in method matmul of type object at": "matmul",
"<built-in method add of type object at": "add",
"<built-in method bmm of type object at": "bmm",
},
}

BLOCK_PATTERNS = [
Expand Down Expand Up @@ -85,6 +90,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
Returns:
cfgs (dict): updated configs.
"""
ori_user_cfg = copy.deepcopy(user_cfg)
tmp_user_cfg = OrderedDict()
for op in user_cfg: # map ipex op_name to pt op_name
for i, op_name in enumerate(op):
Expand All @@ -94,9 +100,9 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
break
user_cfg = tmp_user_cfg
for op_name in user_cfg:
inc_op_cfg = user_cfg[op_name]

for op_name in tmp_user_cfg:
inc_op_cfg = tmp_user_cfg[op_name]
for i, name in enumerate(op_name[0]):
# to int8
ipex_op_cfg = op_infos_from_cfgs[name]
Expand Down Expand Up @@ -154,7 +160,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
else:
pass
cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg
return cfgs, user_cfg
return cfgs, ori_user_cfg


def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False): # pragma: no cover
Expand Down Expand Up @@ -333,8 +339,8 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
elif "method" in ipex_op_type: # "<method 'add' of 'torch._C._TensorBase' objects>"
method = ipex_op_type.split("'")[1]
op_name_info.append((module_fqn, method))
elif "Convolution" in ipex_op_type: # "Convolution_Relu"
op_name_info.append((module_fqn, "Conv2d"))
elif "_" in ipex_op_type: # "Convolution_Relu", "Linear_Relu"
op_name_info.append((module_fqn, ipex_op_type.split("_")[0]))
else:
re_flag = False
for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
Expand Down Expand Up @@ -394,32 +400,7 @@ def dump_model_op_stats(user_cfg):
"""
res = dict()
for k, v in user_cfg.items():
op_type_list = k[-1].split("><")
op_type = ""
for op in op_type_list:
if "class" in op:
op_type = (
op[op.rfind(".") + 1 : op.rfind("'")]
if op_type == ""
else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")]
)
elif "method" in op:
start = op.find("'") + 1
if start > 1:
op_type = (
op[start : op.find("'", start)]
if op_type == ""
else op_type + "&" + op[start : op.find("'", start)]
)
else:
start = op.find("method") + 7
op_type = (
op[start : op.find(" ", start)]
if op_type == ""
else op_type + "&" + op[start : op.find(" ", start)]
)
else:
op_type = op if op_type == "" else op_type + "&" + op
op_type = k[1]
if op_type not in res.keys():
res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0}
if v["weight"]["dtype"] == "int8":
Expand Down