18 changes: 10 additions & 8 deletions neural_compressor/adaptor/pytorch.py
@@ -3579,6 +3579,16 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 return q_model

         self.tune_cfg["fx_sub_module_list"] = self.sub_module_list
+
+        # BF16 fallback
+        if (
+            len(self.tune_cfg["bf16_ops_list"]) > 0
+            and self.version.release >= Version("1.11.0").release
+            and self.use_bf16
+            and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
+        ):  # pragma: no cover
+            q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
+
         if self.approach == "quant_aware_training":
             q_model._model.train()
         if self.sub_module_list is None:
@@ -3665,14 +3675,6 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 self.sub_module_list, q_model._model, prefix="", custom_config=self.prepare_custom_config_dict
             )

-        if (
-            len(self.tune_cfg["bf16_ops_list"]) > 0
-            and self.version.release >= Version("1.11.0").release
-            and self.use_bf16
-            and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
-        ):  # pragma: no cover
-            q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)
-
         self.fused_dict = self.get_fused_list(q_model.model)
         q_model.is_quantized = True
         q_model.q_config = copy.deepcopy(self.tune_cfg)
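Note on the relocation: the BF16 fallback now runs right after fx_sub_module_list is recorded, before the quant_aware_training branch, instead of at the tail of quantize(). The gate itself is unchanged: a non-empty bf16_ops_list, PyTorch >= 1.11.0, use_bf16 enabled, and either native CPU BF16 support or FORCE_BF16=1. A minimal standalone sketch of that gate, for illustration only (bf16_fallback_enabled, torch_version, and cpu_has_bf16 are hypothetical names, not library API):

import os
from packaging.version import Version

def bf16_fallback_enabled(torch_version, use_bf16, bf16_ops, cpu_has_bf16):
    """Restates the four conditions guarding the Convert() call above."""
    return (
        len(bf16_ops) > 0
        and Version(torch_version).release >= Version("1.11.0").release
        and use_bf16
        and (cpu_has_bf16 or os.getenv("FORCE_BF16") == "1")
    )

# Example: force the fallback on a machine without native BF16 support.
os.environ["FORCE_BF16"] = "1"
assert bf16_fallback_enabled("1.13.1", True, ["Linear"], cpu_has_bf16=False)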
25 changes: 0 additions & 25 deletions neural_compressor/adaptor/torch_utils/bf16_convert.py
@@ -17,7 +17,6 @@
"""Bf16 Convert for Torch Utils."""
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from ...utils import logger

@@ -58,8 +57,6 @@ def Convert(model, tune_cfg):
     if len(bf16_ops_list) > 0:
         logger.info("Convert operators to bfloat16")
         mixed_precision_model = _bf16_wrapper_model(model, bf16_ops_list)
-        if fx_sub_module_list is not None and len(fx_sub_module_list) > 0:
-            mixed_precision_model = bf16_symbolic_trace(mixed_precision_model, fx_sub_module_list)
     return mixed_precision_model


@@ -73,25 +70,3 @@ def _bf16_wrapper_model(model, bf16_ops_list, prefix=""):
         _bf16_wrapper_model(child, bf16_ops_list, op_name)
         setattr(model, name, child)
     return model


-def bf16_symbolic_trace(model, fx_sub_module_list, prefix=""):
-    """Symbolic trace for bf16 models.
-
-    Args:
-        model (object): the input model.
-        fx_sub_module_list (list): _description_
-        prefix (str): prefix of op name.
-
-    Returns:
-        model (object)
-    """
-    for name, child in model.named_children():
-        op_name = prefix + "." + name if prefix != "" else name
-        for fx_sub_module_name in fx_sub_module_list:
-            if op_name == fx_sub_module_name:
-                child = symbolic_trace(child)
-            else:
-                bf16_symbolic_trace(child, fx_sub_module_list, op_name)
-        setattr(model, name, child)
-    return model
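With bf16_symbolic_trace removed, Convert() relies solely on _bf16_wrapper_model, which recursively walks named_children() and swaps each op named in bf16_ops_list for a BF16 wrapper via setattr(). As a rough illustration of what such a wrapper can look like (BF16Wrapper below is a hypothetical stand-in, not the library's actual wrapper class):

import torch
import torch.nn as nn

class BF16Wrapper(nn.Module):
    """Runs the wrapped op in bfloat16, returning float32 output."""

    def __init__(self, module):
        super().__init__()
        self.module = module.bfloat16()  # cast parameters/buffers to bf16

    def forward(self, x):
        return self.module(x.to(torch.bfloat16)).float()

# Usage sketch: wrap one op, analogous to the setattr() call in the diff.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
model[0] = BF16Wrapper(model[0])
print(model(torch.randn(2, 8)).dtype)  # torch.float32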