Update utility.py
violetch24 authored Jul 12, 2024
commit fcd8c77386b9eeb3ddbed8ee259b79feeb202678
45 changes: 1 addition & 44 deletions neural_compressor/torch/algorithms/static_quant/utility.py
@@ -24,12 +24,11 @@

 try:
     import intel_extension_for_pytorch as ipex
-    import prettytable as pt
 except:  # pragma: no cover
     pass

 from neural_compressor.common.utils import DEFAULT_WORKSPACE, CpuInfo
-from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger
+from neural_compressor.torch.utils import Statistics, get_ipex_version, get_torch_version, logger

 version = get_torch_version()
 ipex_ver = get_ipex_version()
@@ -608,48 +607,6 @@ def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids
     return quantizable_ops


-class Statistics:  # pragma: no cover
-    """The statistics printer."""
-
-    def __init__(self, data, header, field_names, output_handle=logger.info):
-        """Init a Statistics object.
-
-        Args:
-            data: The statistics data
-            header: The table header
-            field_names: The field names
-            output_handle: The output logging method
-        """
-        self.field_names = field_names
-        self.header = header
-        self.data = data
-        self.output_handle = output_handle
-        self.tb = pt.PrettyTable(min_table_width=40)
-
-    def print_stat(self):
-        """Print the statistics."""
-        valid_field_names = []
-        for index, value in enumerate(self.field_names):
-            if index < 2:
-                valid_field_names.append(value)
-                continue
-
-            if any(i[index] for i in self.data):
-                valid_field_names.append(value)
-        self.tb.field_names = valid_field_names
-        for i in self.data:
-            tmp_data = []
-            for index, value in enumerate(i):
-                if self.field_names[index] in valid_field_names:
-                    tmp_data.append(value)
-            if any(tmp_data[1:]):
-                self.tb.add_row(tmp_data)
-        lines = self.tb.get_string().split("\n")
-        self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|")
-        for i in lines:
-            self.output_handle(i)
-
-
 class TransformerBasedModelBlockPatternDetector:  # pragma: no cover
     """Detect the attention block and FFN block in transformer-based model."""

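For context, a minimal sketch of how the relocated Statistics helper is used after this commit: the class is no longer defined locally in static_quant/utility.py but imported from neural_compressor.torch.utils, as the new import line above shows. The constructor signature (data, header, field_names, output_handle=logger.info) and the print_stat method come from the removed code; the table data, header, and field names below are illustrative assumptions, not values from the codebase.

    from neural_compressor.torch.utils import Statistics

    # Illustrative rows: [op type, int8 op count, fp32 op count].
    data = [
        ["Conv2d", 12, 0],
        ["Linear", 4, 1],
    ]

    # print_stat() drops columns beyond the first two whose values are all
    # falsy, skips rows that are empty past the first column, then logs the
    # table with the header centered in a "*"-padded banner above it.
    Statistics(
        data,
        header="Mixed Precision Statistics",  # assumed header text
        field_names=["Op Type", "INT8", "FP32"],
    ).print_stat()

Callers inside this file behave exactly as before; only the class's home module changed, which lets other torch backends share one table printer instead of each keeping a copy.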