Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,19 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]

quant_config[QCOM_SCALE_OFFSET] = scale_offset_arr
# special case for 4 bits
if (
quant_config[QCOM_DTYPE] == torch.int8
and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15
):
quant_config[QCOM_BITWIDTH] = 4
return (
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
quant_config,
)
if quant_config[QCOM_DTYPE] == torch.int8:
if quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 3:
quant_config[QCOM_BITWIDTH] = 2
return (
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
quant_config,
)
elif quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15:
quant_config[QCOM_BITWIDTH] = 4
return (
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET,
quant_config,
)
return (
PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET,
quant_config,
Expand Down Expand Up @@ -338,6 +341,9 @@ def get_quant_tensor_value(
if quant_configs.get(QCOM_BITWIDTH) == 4:
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
tensor = torch.bitwise_and(mask, tensor)
elif quant_configs.get(QCOM_BITWIDTH) == 2:
mask = torch.full(tensor.size(), 0x03, dtype=torch.int8)
tensor = torch.bitwise_and(mask, tensor)
return tensor

def get_tensor_type(
Expand Down
68 changes: 62 additions & 6 deletions backends/qualcomm/quantizer/qconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,51 @@ def get_8a4w_qnn_ptq_config(
return quantization_config


# 2 bits weight quantization only supports per channel and symmetric.
def get_16a2w_qnn_ptq_config(
act_symmetric: bool = False,
act_observer=MovingAverageMinMaxObserver,
eps: float = None,
) -> QuantizationConfig:
# the smallest defaults to DEFAULT_EPS_16BIT
extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT}

act_quantization_spec = QuantizationSpec(
dtype=torch.int32,
quant_min=torch.iinfo(torch.uint16).min,
quant_max=torch.iinfo(torch.uint16).max,
qscheme=(
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
),
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
)

weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-2,
quant_max=1,
qscheme=torch.per_tensor_symmetric,
observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
)

bias_quantization_spec = QuantizationSpec(
dtype=torch.int32,
quant_min=torch.iinfo(torch.int32).min,
quant_max=torch.iinfo(torch.int32).max,
qscheme=torch.per_tensor_symmetric,
observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
)

quantization_config = QuantizationConfig(
input_activation=act_quantization_spec,
output_activation=act_quantization_spec,
weight=weight_quantization_spec,
bias=bias_quantization_spec,
)

return quantization_config


# 4 bits quantization only supports specific ops.
def get_16a4w_qnn_ptq_config(
act_symmetric: bool = False,
Expand Down Expand Up @@ -435,7 +480,7 @@ def get_ptq_per_channel_quant_config(
torch.int8,
torch.int16,
}
supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
supported_weight_dtypes = {torch.int2, torch.int4, torch.int8, torch.int16}
assert (
act_dtype in supported_act_types
), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
Expand Down Expand Up @@ -468,12 +513,23 @@ def get_ptq_per_channel_quant_config(
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
)

q_dtype = weight_dtype
if weight_dtype == torch.int4:
q_dtype = torch.int8
q_min = -7
q_max = 7
elif weight_dtype == torch.int2:
q_dtype = torch.int8
q_min = -2
q_max = 1
else:
q_min = torch.iinfo(weight_dtype).min + 1
q_max = torch.iinfo(weight_dtype).max

weight_quantization_spec = QuantizationSpec(
dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
quant_min=(
-7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
),
quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
dtype=q_dtype,
quant_min=q_min,
quant_max=q_max,
qscheme=torch.per_channel_symmetric,
ch_axis=ch_axis,
observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(**extra_args),
Expand Down
12 changes: 12 additions & 0 deletions backends/qualcomm/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

from .qconfig import (
get_16a16w_qnn_ptq_config,
get_16a2w_qnn_ptq_config,
get_16a4w_qnn_ptq_config,
get_16a4w_qnn_qat_config,
get_16a8w_qnn_ptq_config,
Expand All @@ -65,6 +66,7 @@
__all__ = [
"QnnQuantizer",
"QuantDtype",
"get_16a2w_qnn_ptq_config",
"get_16a4w_qnn_ptq_config",
"get_16a8w_qnn_ptq_config",
"get_16a8w_qnn_qat_config",
Expand All @@ -89,6 +91,7 @@ class QuantDtype(IntEnum):
use_16a4w_block = 3
use_8a8w = 4
use_8a4w = 5
use_16a2w = 6


QUANT_CONFIG_DICT = {
Expand Down Expand Up @@ -120,6 +123,15 @@ class QuantDtype(IntEnum):
),
None,
),
(QuantDtype.use_16a2w, False): (
get_16a2w_qnn_ptq_config,
partial(
get_ptq_per_channel_quant_config,
act_dtype=torch.uint16,
weight_dtype=torch.int2,
),
None,
),
(QuantDtype.use_16a4w_block, False): (
get_16a4w_qnn_ptq_config,
partial(
Expand Down
29 changes: 29 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2468,6 +2468,35 @@ def setUp(self):
shared_buffer=TestQNN.shared_buffer,
)

def test_qnn_backend_16a2w_conv2d(self):
modules = [Conv2dSingle(), Conv2dSingle(bias=False)] # noqa: F405
sample_input = (torch.randn([1, 1, 3, 3]),)
for i, module in enumerate(modules):
with self.subTest(i=i):
qdq_module = self.get_qdq_module(
module,
sample_input,
is_linear_per_channel=True,
quant_dtype=QuantDtype.use_16a2w,
)
self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_16a2w_linear(self):
sample_input = (torch.randn([3, 512]),)
for i, per_channel, use_bias in [
(1, True, False),
(2, True, True),
]:
with self.subTest(i=i):
module = Linear(use_bias=use_bias) # noqa: F405
qdq_module = self.get_qdq_module(
module,
sample_input,
is_linear_per_channel=per_channel,
quant_dtype=QuantDtype.use_16a2w,
)
self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_16a4w_conv2d(self):
modules = [Conv2dSingle(), Conv2dSingle(bias=False)] # noqa: F405
sample_input = (torch.randn([1, 1, 3, 3]),)
Expand Down
Loading