1 change: 0 additions & 1 deletion neural_compressor/torch/algorithms/weight_only/utility.py
@@ -1132,4 +1132,3 @@ def convert_dtype_str2torch(str_dtype):
         return torch.bfloat16
     else:
         assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype)
-
9 changes: 8 additions & 1 deletion neural_compressor/torch/quantization/algorithm_entry.py
@@ -45,7 +45,14 @@
     StaticQuantConfig,
     TEQConfig,
 )
-from neural_compressor.torch.utils import get_quantizer, is_ipex_imported, logger, postprocess_model, register_algo, dump_model_op_stats
+from neural_compressor.torch.utils import (
+    dump_model_op_stats,
+    get_quantizer,
+    is_ipex_imported,
+    logger,
+    postprocess_model,
+    register_algo,
+)
 from neural_compressor.torch.utils.constants import PT2E_DYNAMIC_QUANT, PT2E_STATIC_QUANT
 
 
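For context, a minimal sketch of how an algorithm entry might use the regrouped imports. `example_quant_entry` and its arguments are hypothetical, not code from this PR; only the `dump_model_op_stats(mode, tune_cfg)` signature visible in `utility.py` below is assumed.

# Hypothetical entry -- illustrates the imports above; not part of this PR.
from neural_compressor.torch.utils import dump_model_op_stats, logger


def example_quant_entry(model, configs_mapping, mode):
    # configs_mapping maps (op_name, op_type) tuples to config objects,
    # matching what dump_model_op_stats expects as tune_cfg.
    logger.info("Dumping mixed-precision op statistics.")
    # dump_model_op_stats returns early when mode == Mode.PREPARE
    # (see utility.py below), so entries can call it unconditionally.
    dump_model_op_stats(mode, configs_mapping)
    return model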
110 changes: 56 additions & 54 deletions neural_compressor/torch/utils/utility.py
@@ -15,8 +15,8 @@

 from typing import Callable, Dict, List, Tuple, Union
 
-import torch
 import prettytable as pt
+import torch
 from typing_extensions import TypeAlias
 
 from neural_compressor.common.utils import LazyImport, Mode, logger
@@ -165,6 +165,7 @@ def postprocess_model(model, mode, quantizer):
     if getattr(model, "quantizer", False):
         del model.quantizer
 
+
 class Statistics:  # pragma: no cover
     """The statistics printer."""
 
@@ -205,58 +206,59 @@ def print_stat(self):
         self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|")
         for i in lines:
             self.output_handle(i)
-
-
+
+
 def dump_model_op_stats(mode, tune_cfg):
-    """This is a function to dump quantizable ops of model to user.
+    """This is a function to dump quantizable ops of model to user.
 
-    Args:
-        model (object): input model
-        tune_cfg (dict): quantization config
-    Returns:
-        None
-    """
-    if mode == Mode.PREPARE:
-        return
-    res = {}
-    # collect all dtype info and build empty results with existing op_type
-    dtype_set = set()
-    for op, config in tune_cfg.items():
-        op_type = op[1]
-        config = config.to_dict()
-        # import pdb; pdb.set_trace()
-        if not config["dtype"] == "fp32":
-            num_bits = config["bits"]
-            group_size = config["group_size"]
-            dtype_str = "A32W{}G{}".format(num_bits, group_size)
-            dtype_set.add(dtype_str)
-    dtype_set.add("FP32")
-    dtype_list = list(dtype_set)
-    dtype_list.sort()
-    for op, config in tune_cfg.items():
-        config = config.to_dict()
-        op_type = op[1]
-        if op_type not in res.keys():
-            res[op_type] = {dtype: 0 for dtype in dtype_list}
-
-    # fill in results with op_type and dtype
-    for op, config in tune_cfg.items():
-        config = config.to_dict()
-        if config["dtype"] == "fp32":
-            res[op_type]["FP32"] += 1
-        else:
-            num_bits = config["bits"]
-            group_size = config["group_size"]
-            dtype_str = "A32W{}G{}".format(num_bits, group_size)
-            res[op_type][dtype_str] += 1
-
-    # update stats format for dump.
-    field_names = ["Op Type", "Total"]
-    field_names.extend(dtype_list)
-    output_data = []
-    for op_type in res.keys():
-        field_results = [op_type, sum(res[op_type].values())]
-        field_results.extend([res[op_type][dtype] for dtype in dtype_list])
-        output_data.append(field_results)
-
-    Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
+    Args:
+        model (object): input model
+        tune_cfg (dict): quantization config
+    Returns:
+        None
+    """
+    if mode == Mode.PREPARE:
+        return
+    res = {}
+    # collect all dtype info and build empty results with existing op_type
+    dtype_set = set()
+    for op, config in tune_cfg.items():
+        op_type = op[1]
+        config = config.to_dict()
+        # import pdb; pdb.set_trace()
+        if not config["dtype"] == "fp32":
+            num_bits = config["bits"]
+            group_size = config["group_size"]
+            dtype_str = "A32W{}G{}".format(num_bits, group_size)
+            dtype_set.add(dtype_str)
+    dtype_set.add("FP32")
+    dtype_list = list(dtype_set)
+    dtype_list.sort()
+
+    for op, config in tune_cfg.items():
+        config = config.to_dict()
+        op_type = op[1]
+        if op_type not in res.keys():
+            res[op_type] = {dtype: 0 for dtype in dtype_list}
+
+    # fill in results with op_type and dtype
+    for op, config in tune_cfg.items():
+        config = config.to_dict()
+        if config["dtype"] == "fp32":
+            res[op_type]["FP32"] += 1
+        else:
+            num_bits = config["bits"]
+            group_size = config["group_size"]
+            dtype_str = "A32W{}G{}".format(num_bits, group_size)
+            res[op_type][dtype_str] += 1
+
+    # update stats format for dump.
+    field_names = ["Op Type", "Total"]
+    field_names.extend(dtype_list)
+    output_data = []
+    for op_type in res.keys():
+        field_results = [op_type, sum(res[op_type].values())]
+        field_results.extend([res[op_type][dtype] for dtype in dtype_list])
+        output_data.append(field_results)
+
+    Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
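One note on the moved function: in both the removed and the added version, the second loop ("fill in results") never rebinds `op_type` from `op`, so it keeps the value left over from the previous loop and every count can land on the wrong op type. Below is a self-contained sketch of the intended aggregation with that fixed; `ToyConfig` and the sample `tune_cfg` are hypothetical stand-ins, and only the `(op_name, op_type)` key layout and the `dtype`/`bits`/`group_size` fields are taken from the diff above.

# Self-contained sketch of the aggregation dump_model_op_stats performs.
# ToyConfig is hypothetical; the real config objects come from
# neural_compressor and only share the to_dict() fields used here.
class ToyConfig:
    def __init__(self, dtype, bits=None, group_size=None):
        self._cfg = {"dtype": dtype, "bits": bits, "group_size": group_size}

    def to_dict(self):
        return self._cfg


tune_cfg = {
    ("model.fc1", "Linear"): ToyConfig("int4", bits=4, group_size=32),
    ("model.fc2", "Linear"): ToyConfig("fp32"),
    ("model.lm_head", "Linear"): ToyConfig("int4", bits=4, group_size=128),
}

res = {}
for op, config in tune_cfg.items():
    op_type = op[1]  # re-derive op_type inside every loop over tune_cfg
    cfg = config.to_dict()
    # Label format from the diff: "A32W{bits}G{group_size}", i.e. activations
    # kept at 32-bit while only weights are quantized, or plain "FP32".
    label = "FP32" if cfg["dtype"] == "fp32" else "A32W{}G{}".format(cfg["bits"], cfg["group_size"])
    res.setdefault(op_type, {}).setdefault(label, 0)
    res[op_type][label] += 1

print(res)  # {'Linear': {'A32W4G32': 1, 'FP32': 1, 'A32W4G128': 1}}

Under this scheme the table printed by Statistics gets one row per op type, a total column, and one count column per observed dtype label.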