From 79a29daf81d78c5425fb8dce5f5e12645e06817c Mon Sep 17 00:00:00 2001 From: "hyungkeun.park" Date: Mon, 18 May 2026 08:47:51 +0000 Subject: [PATCH] Add quantize/dequantize_per_channel_group support to QNN backend quantized_decomposed.quantize_per_channel_group and dequantize_per_channel_group are used for LLM weight-only quantization (e.g. int4 group-wise) but were not recognized by the QNN backend, causing the ops to be decomposed or failing with a KeyError in InsertIOQDQ. Five files are changed: - builders/node_visitor.py: add both ops to q_ops/dq_ops and to the per_block_encoding set so get_quant_encoding_conf routes them through make_qnn_per_block_config, matching the existing torchao affine path. - partition/utils.py: add both ops to get_skip_decomp_table so they are preserved as-is during torch.export and reach the backend. - _passes/utils.py: remap "scales" -> QCOM_SCALE in get_quant_attrs, parallel to the existing "input_dtype" -> QCOM_DTYPE remap, so AnnotateQuantAttrs correctly propagates the scale tensor for per-channel-group nodes. - _passes/qnn_pass_manager.py: dynamically register both ops into node_visitor.q_ops/dq_ops in get_to_edge_transform_passes, following the same pattern used for torchao ops. - _passes/insert_io_qdq.py: add per-channel-group entries to q_dq_map to fix KeyError when a pre-quantized weight node with dequantize_per_channel_group encoding feeds the graph output. A unit test is added to test_passes.py that injects per_channel_group quant attrs onto a node feeding the output and verifies InsertIOQDQ completes without KeyError. --- backends/qualcomm/_passes/insert_io_qdq.py | 4 ++ backends/qualcomm/_passes/qnn_pass_manager.py | 6 +++ backends/qualcomm/_passes/utils.py | 4 +- backends/qualcomm/builders/node_visitor.py | 4 ++ backends/qualcomm/partition/utils.py | 2 + backends/qualcomm/tests/test_passes.py | 42 +++++++++++++++++++ 6 files changed, 61 insertions(+), 1 deletion(-) diff --git a/backends/qualcomm/_passes/insert_io_qdq.py b/backends/qualcomm/_passes/insert_io_qdq.py index 0c6e539f0b8..edc23e38895 100644 --- a/backends/qualcomm/_passes/insert_io_qdq.py +++ b/backends/qualcomm/_passes/insert_io_qdq.py @@ -42,6 +42,10 @@ class InsertIOQDQ(ExportPass): exir_ops.edge.quantized_decomposed.quantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, # per channel (dequantize -> dequantize, for pre-quantized params) exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + # per channel group (quantize -> dequantize) + exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default, + # per channel group (dequantize -> dequantize, for pre-quantized weight params) + exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default, } def __init__(self, edge_program: torch.export.ExportedProgram): diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 57354af11de..bf636b0254a 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -168,6 +168,12 @@ def get_to_edge_transform_passes( node_visitor.q_ops.add(exir_ops.edge.torchao.quantize_affine.default) node_visitor.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default) + node_visitor.q_ops.add( + exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default + ) + node_visitor.dq_ops.add( + exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default + ) passes_job = ( passes_job if passes_job is not None else get_capture_program_passes() diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 04371d61e1c..746622580cc 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -8,7 +8,7 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter -from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING +from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING, QCOM_SCALE from executorch.exir.dialects._ops import ops as exir_ops from torch._subclasses import FakeTensor @@ -43,6 +43,8 @@ def get_quant_attrs( # remap key for compatibility - block quantization only if dtype := quant_attrs.get("input_dtype", None): quant_attrs[QCOM_DTYPE] = dtype + if quant_attrs.get("scales") is not None: + quant_attrs[QCOM_SCALE] = quant_attrs["scales"] quant_attrs[QCOM_ENCODING] = quant_node.target return quant_attrs diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 7380f7a8191..ca14894ce11 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -90,12 +90,14 @@ exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default, } dq_ops = { exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default, } @@ -307,6 +309,8 @@ def get_quant_encoding_conf( per_block_encoding = { exir_ops.edge.torchao.quantize_affine.default, exir_ops.edge.torchao.dequantize_affine.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default, } if quant_attrs[QCOM_ENCODING] in per_block_encoding: return self.make_qnn_per_block_config(node, quant_attrs) diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index a83444a56b2..a46c67a9ac5 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -74,5 +74,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.unbind.int, torch.ops.torchao.quantize_affine.default, torch.ops.torchao.dequantize_affine.default, + torch.ops.quantized_decomposed.quantize_per_channel_group.default, + torch.ops.quantized_decomposed.dequantize_per_channel_group.default, ] return do_not_decompose diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py index 1f007628e61..9c07243c980 100644 --- a/backends/qualcomm/tests/test_passes.py +++ b/backends/qualcomm/tests/test_passes.py @@ -13,6 +13,7 @@ from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.tests.models import TopKandIndex +from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING, QCOM_SCALE from executorch.backends.qualcomm.utils.utils import ( generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, @@ -102,6 +103,47 @@ def test_insert_io_qdq_no_revisit(self): # one quantize (input) and one dequantize (output) = +2 nodes. self.assertEqual(node_count_after, node_count_before + 2) + def test_insert_io_qdq_per_channel_group_no_key_error(self): + """InsertIOQDQ must not KeyError for per_channel_group encoded nodes.""" + from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS + + gm, ep = self._build_quantized_graph() + + # Find an existing weight-like node to reuse its meta["val"] + output_node = next(n for n in gm.graph.nodes if n.op == "output") + any_node = next(n for n in gm.graph.nodes if n.op == "placeholder") + + # Inject per_channel_group quant attrs on the placeholder, + # simulating a pre-quantized weight with group quantization. + scales = torch.ones(4, 1) + any_node.meta[QCOM_QUANT_ATTRS] = { + QCOM_ENCODING: exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default, + QCOM_SCALE: scales, + QCOM_DTYPE: torch.int8, + "scales": scales, + "zero_points": None, + "quant_min": -8, + "quant_max": 7, + "dtype": torch.int8, + "group_size": 1, + "output_dtype": torch.float32, + } + + # Wire that placeholder into output so InsertIOQDQ hits line 155 + old_out_args = output_node.args[0] + if not isinstance(old_out_args, tuple): + old_out_args = (old_out_args,) + output_node.args = (old_out_args + (any_node,),) + gm.graph.lint() + gm.recompile() + + # Should not raise KeyError + pass_instance = InsertIOQDQ(ep) + try: + pass_instance._insert(gm) + except KeyError as e: + self.fail(f"InsertIOQDQ raised KeyError for per_channel_group encoding: {e}") + def test_insert_reshape_for_argmax(self): class ArgmaxModule(torch.nn.Module): def forward(self, x):