diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 7380f7a8191..aa9267cf891 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -248,16 +248,19 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS] quant_config[QCOM_SCALE_OFFSET] = scale_offset_arr - # special case for 4 bits - if ( - quant_config[QCOM_DTYPE] == torch.int8 - and quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15 - ): - quant_config[QCOM_BITWIDTH] = 4 - return ( - PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET, - quant_config, - ) + if quant_config[QCOM_DTYPE] == torch.int8: + if quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 3: + quant_config[QCOM_BITWIDTH] = 2 + return ( + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET, + quant_config, + ) + elif quant_config[QCOM_QUANT_MAX] - quant_config[QCOM_QUANT_MIN] <= 15: + quant_config[QCOM_BITWIDTH] = 4 + return ( + PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET, + quant_config, + ) return ( PyQnnManager.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET, quant_config, @@ -338,6 +341,9 @@ def get_quant_tensor_value( if quant_configs.get(QCOM_BITWIDTH) == 4: mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8) tensor = torch.bitwise_and(mask, tensor) + elif quant_configs.get(QCOM_BITWIDTH) == 2: + mask = torch.full(tensor.size(), 0x03, dtype=torch.int8) + tensor = torch.bitwise_and(mask, tensor) return tensor def get_tensor_type( diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index b3c5edf9910..cdd0c18b0b9 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -219,6 +219,51 @@ def get_8a4w_qnn_ptq_config( return quantization_config +# 2 bits weight quantization only supports per channel and symmetric. +def get_16a2w_qnn_ptq_config( + act_symmetric: bool = False, + act_observer=MovingAverageMinMaxObserver, + eps: float = None, +) -> QuantizationConfig: + # the smallest defaults to DEFAULT_EPS_16BIT + extra_args: Dict[str, Any] = {"eps": eps if eps else DEFAULT_EPS_16BIT} + + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-2, + quant_max=1, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + # 4 bits quantization only supports specific ops. def get_16a4w_qnn_ptq_config( act_symmetric: bool = False, @@ -435,7 +480,7 @@ def get_ptq_per_channel_quant_config( torch.int8, torch.int16, } - supported_weight_dtypes = {torch.int4, torch.int8, torch.int16} + supported_weight_dtypes = {torch.int2, torch.int4, torch.int8, torch.int16} assert ( act_dtype in supported_act_types ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" @@ -468,12 +513,23 @@ def get_ptq_per_channel_quant_config( observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), ) + q_dtype = weight_dtype + if weight_dtype == torch.int4: + q_dtype = torch.int8 + q_min = -7 + q_max = 7 + elif weight_dtype == torch.int2: + q_dtype = torch.int8 + q_min = -2 + q_max = 1 + else: + q_min = torch.iinfo(weight_dtype).min + 1 + q_max = torch.iinfo(weight_dtype).max + weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, - quant_min=( - -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 - ), - quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, + dtype=q_dtype, + quant_min=q_min, + quant_max=q_max, qscheme=torch.per_channel_symmetric, ch_axis=ch_axis, observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(**extra_args), diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 5d297ef14c4..8688d6f2773 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -44,6 +44,7 @@ from .qconfig import ( get_16a16w_qnn_ptq_config, + get_16a2w_qnn_ptq_config, get_16a4w_qnn_ptq_config, get_16a4w_qnn_qat_config, get_16a8w_qnn_ptq_config, @@ -65,6 +66,7 @@ __all__ = [ "QnnQuantizer", "QuantDtype", + "get_16a2w_qnn_ptq_config", "get_16a4w_qnn_ptq_config", "get_16a8w_qnn_ptq_config", "get_16a8w_qnn_qat_config", @@ -89,6 +91,7 @@ class QuantDtype(IntEnum): use_16a4w_block = 3 use_8a8w = 4 use_8a4w = 5 + use_16a2w = 6 QUANT_CONFIG_DICT = { @@ -120,6 +123,15 @@ class QuantDtype(IntEnum): ), None, ), + (QuantDtype.use_16a2w, False): ( + get_16a2w_qnn_ptq_config, + partial( + get_ptq_per_channel_quant_config, + act_dtype=torch.uint16, + weight_dtype=torch.int2, + ), + None, + ), (QuantDtype.use_16a4w_block, False): ( get_16a4w_qnn_ptq_config, partial( diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 688dddf5c2a..d664b9871e8 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -2468,6 +2468,35 @@ def setUp(self): shared_buffer=TestQNN.shared_buffer, ) + def test_qnn_backend_16a2w_conv2d(self): + modules = [Conv2dSingle(), Conv2dSingle(bias=False)] # noqa: F405 + sample_input = (torch.randn([1, 1, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + qdq_module = self.get_qdq_module( + module, + sample_input, + is_linear_per_channel=True, + quant_dtype=QuantDtype.use_16a2w, + ) + self.lower_module_and_test_output(qdq_module, sample_input) + + def test_qnn_backend_16a2w_linear(self): + sample_input = (torch.randn([3, 512]),) + for i, per_channel, use_bias in [ + (1, True, False), + (2, True, True), + ]: + with self.subTest(i=i): + module = Linear(use_bias=use_bias) # noqa: F405 + qdq_module = self.get_qdq_module( + module, + sample_input, + is_linear_per_channel=per_channel, + quant_dtype=QuantDtype.use_16a2w, + ) + self.lower_module_and_test_output(qdq_module, sample_input) + def test_qnn_backend_16a4w_conv2d(self): modules = [Conv2dSingle(), Conv2dSingle(bias=False)] # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),)