update doc and UTs for set_local
Signed-off-by: xin3he <[email protected]>
xin3he committed Jun 21, 2024
commit 98ae9167a19e46a2dfef143bfc56c73c961a21f5
21 changes: 21 additions & 0 deletions docs/3x/PyTorch.md
@@ -223,3 +223,24 @@ def load(output_dir="./saved_results", model=None):
</tr>
</tbody>
</table>

2. How to set a different configuration for a specific op_name or op_type?
> INC provides a `set_local` method on the global configuration object for setting a custom configuration on specific operators.

```python
def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
    """Set a custom configuration for specific operators on top of the global configuration.

    Args:
        operator_name_or_list (Union[List, str, Callable]): an operator name, a list of operator names,
            a name pattern (regex), or an operator type to apply the configuration to.
        config (BaseConfig): the custom configuration to apply to the matched operators.
    """
```

> Demo:

```python
quant_config = RTNConfig() # Initialize global configuration with default bits=4
quant_config.set_local(".*mlp.*", RTNConfig(bits=8)) # For layers with "mlp" in their names, set bits=8
quant_config.set_local("Conv1d", RTNConfig(dtype="fp32")) # For Conv1d layers, do not quantize them.
```
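
> `set_local` also accepts a list of operator names, applying the same configuration to every entry. A minimal sketch (the layer names below are hypothetical examples):

```python
quant_config = RTNConfig()  # global default configuration
# Keep the listed layers unquantized (fp32); the names are hypothetical examples.
quant_config.set_local(["lm_head", "transformer.h.0.mlp.fc_in"], RTNConfig(dtype="fp32"))
```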
9 changes: 9 additions & 0 deletions neural_compressor/common/base_config.py
@@ -199,6 +199,15 @@ def local_config(self, config):
        self._local_config = config

    def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
        """Set a custom configuration for specific operators on top of the global configuration.

        Args:
            operator_name_or_list (Union[List, str, Callable]): an operator name, a list of operator names,
                a name pattern (regex), or an operator type to apply the configuration to.
            config (BaseConfig): the custom configuration to apply to the matched operators.

        Returns:
            BaseConfig: the updated configuration object.
        """
        if isinstance(operator_name_or_list, list):
            for operator_name in operator_name_or_list:
                if operator_name in self.local_config:
13 changes: 13 additions & 0 deletions test/3x/torch/quantization/weight_only/test_rtn.py
@@ -195,6 +195,19 @@ def test_dtype_params(self, dtype):
        assert torch.allclose(out, self.label, atol=0.11), "Accuracy gap atol > 0.11 is unexpected."
        assert torch.allclose(out, out_next), "output should be same"

    def test_mix_dtype(self):
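        # Mix bit-widths and dtypes in one model via per-operator set_local overrides.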
        model = copy.deepcopy(self.tiny_gptj)
        quant_config = RTNConfig()
        quant_config.set_local(".*mlp.*", RTNConfig(bits=8))
        quant_config.set_local(".*.out_proj", RTNConfig(bits=6))
        quant_config.set_local(".*.k_proj", RTNConfig(dtype="nf4"))
        model = prepare(model, quant_config)
        model = convert(model)
        out = model(self.example_inputs)[0]
        out_next = model(self.example_inputs)[0]
        assert torch.allclose(out, self.label, atol=0.08), "Accuracy gap atol > 0.08 is unexpected."
        assert torch.allclose(out, out_next), "output should be same"

@pytest.mark.parametrize("dtype", ["int4", "nf4"])
@pytest.mark.parametrize("double_quant_bits", [6])
@pytest.mark.parametrize("double_quant_group_size", [8, 256])