Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/qualcomm/_passes/insert_io_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class InsertIOQDQ(ExportPass):
exir_ops.edge.quantized_decomposed.quantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
# per channel (dequantize -> dequantize, for pre-quantized params)
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
# per channel group (quantize -> dequantize)
exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
# per channel group (dequantize -> dequantize, for pre-quantized weight params)
exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
}

def __init__(self, edge_program: torch.export.ExportedProgram):
Expand Down
6 changes: 6 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ def get_to_edge_transform_passes(

node_visitor.q_ops.add(exir_ops.edge.torchao.quantize_affine.default)
node_visitor.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default)
node_visitor.q_ops.add(
exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default
)
node_visitor.dq_ops.add(
exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default
)

passes_job = (
passes_job if passes_job is not None else get_capture_program_passes()
Expand Down
4 changes: 3 additions & 1 deletion backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING
from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING, QCOM_SCALE
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses import FakeTensor

Expand Down Expand Up @@ -43,6 +43,8 @@ def get_quant_attrs(
# remap key for compatibility - block quantization only
if dtype := quant_attrs.get("input_dtype", None):
quant_attrs[QCOM_DTYPE] = dtype
if quant_attrs.get("scales") is not None:
quant_attrs[QCOM_SCALE] = quant_attrs["scales"]

quant_attrs[QCOM_ENCODING] = quant_node.target
return quant_attrs
Expand Down
4 changes: 4 additions & 0 deletions backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,14 @@
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default,
}

# Edge-dialect dequantize operator overloads collected under `dq_ops`.
# NOTE(review): this set is mutable and other modules appear to extend it at
# runtime (e.g. via `node_visitor.dq_ops.add(...)`), so this literal is
# presumably only the baseline membership — confirm against callers.
dq_ops = {
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
}


Expand Down Expand Up @@ -307,6 +309,8 @@ def get_quant_encoding_conf(
per_block_encoding = {
exir_ops.edge.torchao.quantize_affine.default,
exir_ops.edge.torchao.dequantize_affine.default,
exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
}
if quant_attrs[QCOM_ENCODING] in per_block_encoding:
return self.make_qnn_per_block_config(node, quant_attrs)
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/partition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
torch.ops.aten.unbind.int,
torch.ops.torchao.quantize_affine.default,
torch.ops.torchao.dequantize_affine.default,
torch.ops.quantized_decomposed.quantize_per_channel_group.default,
torch.ops.quantized_decomposed.dequantize_per_channel_group.default,
]
return do_not_decompose
42 changes: 42 additions & 0 deletions backends/qualcomm/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.tests.models import TopKandIndex
from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING, QCOM_SCALE
from executorch.backends.qualcomm.utils.utils import (
generate_htp_compiler_spec,
generate_qnn_executorch_compiler_spec,
Expand Down Expand Up @@ -102,6 +103,47 @@ def test_insert_io_qdq_no_revisit(self):
# one quantize (input) and one dequantize (output) = +2 nodes.
self.assertEqual(node_count_after, node_count_before + 2)

def test_insert_io_qdq_per_channel_group_no_key_error(self):
    """Check that InsertIOQDQ handles a per_channel_group encoding without a KeyError.

    A placeholder of the quantized graph is tagged with
    ``dequantize_per_channel_group`` quant attrs and appended to the graph
    outputs; running the pass over that graph must not raise ``KeyError``.
    """
    from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS

    gm, ep = self._build_quantized_graph()

    # Grab the graph's output node and the first placeholder; the placeholder
    # stands in for a pre-quantized weight that uses group quantization.
    graph_output = next(filter(lambda n: n.op == "output", gm.graph.nodes))
    placeholder = next(filter(lambda n: n.op == "placeholder", gm.graph.nodes))

    # Attach per_channel_group quant attrs to the placeholder.
    group_scales = torch.ones(4, 1)
    group_dq_op = (
        exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default
    )
    injected_attrs = {
        QCOM_ENCODING: group_dq_op,
        QCOM_SCALE: group_scales,
        QCOM_DTYPE: torch.int8,
        "scales": group_scales,
        "zero_points": None,
        "quant_min": -8,
        "quant_max": 7,
        "dtype": torch.int8,
        "group_size": 1,
        "output_dtype": torch.float32,
    }
    placeholder.meta[QCOM_QUANT_ATTRS] = injected_attrs

    # Route the placeholder into the graph outputs.
    # NOTE(review): presumably this makes InsertIOQDQ look up the injected
    # encoding on its output-handling path — confirm against the pass.
    returned = graph_output.args[0]
    returned = returned if isinstance(returned, tuple) else (returned,)
    graph_output.args = ((*returned, placeholder),)
    gm.graph.lint()
    gm.recompile()

    # Only a KeyError is converted into a test failure; any other exception
    # propagates unchanged.
    io_qdq_pass = InsertIOQDQ(ep)
    try:
        io_qdq_pass._insert(gm)
    except KeyError as e:
        self.fail(f"InsertIOQDQ raised KeyError for per_channel_group encoding: {e}")

def test_insert_reshape_for_argmax(self):
class ArgmaxModule(torch.nn.Module):
def forward(self, x):
Expand Down
Loading