diff --git a/backends/arm/test/quantizer/test_selective_quantization.py b/backends/arm/test/quantizer/test_selective_quantization.py
index a59a509ce06..ef0c51c58ca 100644
--- a/backends/arm/test/quantizer/test_selective_quantization.py
+++ b/backends/arm/test/quantizer/test_selective_quantization.py
@@ -7,6 +7,7 @@
 from typing import Dict
 
 import torch
+
 from executorch.backends.arm.quantizer import (
     get_symmetric_a16w8_quantization_config,
     get_symmetric_quantization_config,
@@ -16,13 +17,18 @@
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline
 from executorch.backends.arm.tosa import TosaSpecification
+from executorch.backends.test.harness.stages import StageType
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
 from torchvision import models, transforms  # type: ignore[import-untyped]
 from torchvision.ops.misc import Conv2dNormActivation  # type: ignore[import-untyped]
 
 
-def get_quantizer():
+def get_quantizer(use_composable_quantizer: bool = False):
+    """Return a TOSA-1.0+INT TOSAQuantizer with the symmetric config set globally."""
     tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
-    quantizer = TOSAQuantizer(tosa_spec)
+    quantizer = TOSAQuantizer(
+        tosa_spec, use_composable_quantizer=use_composable_quantizer
+    )
     quantizer.set_global(get_symmetric_quantization_config())
     return quantizer
 
@@ -53,6 +59,27 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return x + y
 
 
+class Cat(torch.nn.Module):
+    """Concatenate two tensors along dim 1."""
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        return torch.cat((x, y), dim=1)
+
+
+class LinearGraphTail(torch.nn.Module):
+    """Linear layer followed by relu, sigmoid and neg."""
+
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(10, 10)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear(x)
+        x = torch.relu(x)
+        x = torch.sigmoid(x)
+        return torch.neg(x)
+
+
 class AddSoftmaxAdd(torch.nn.Module):
     module_names = {"add_0": None, "add_1": None}
     module_types = {
@@ -131,6 +158,78 @@ def test_selective_quant_module_type_tosa_INT(model):
     pipeline.run()
 
 
+def test_selective_quant_cat_node_target_none_tosa_INT():
+    """Quantize Cat with aten.cat.default targeted with a None config."""
+    model = Cat()
+    inputs = (torch.randn(1, 2, 4), torch.randn(1, 3, 4))
+
+    quantizer = get_quantizer(use_composable_quantizer=True)
+    quantizer.set_node_target(torch.ops.aten.cat.default, None)
+
+    pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
+        model,
+        inputs,
+        quantizer=quantizer,
+        qspecs={
+            "aten.cat.default": {
+                None: 1,
+            },
+        },
+    )
+
+    pipeline.run()
+
+
+def test_composable_io_none_skips_global_tosa_INT():
+    """Quantize Add with the IO config set to None."""
+    model = Add()
+    inputs = (torch.randn(1, 10), torch.randn(1, 10))
+
+    quantizer = get_quantizer(use_composable_quantizer=True)
+    quantizer.set_io(None)
+
+    pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
+        model,
+        inputs,
+        quantizer=quantizer,
+        input_qspecs={None: 2},
+        output_qspecs={None: 1},
+    )
+
+    pipeline.run()
+
+
+def test_composable_global_none_linear_graph_tail_tosa_INT():
+    """With a None global config, every call_function node is still annotated."""
+    model = LinearGraphTail()
+    inputs = (torch.randn(1, 10),)
+
+    quantizer = get_quantizer(use_composable_quantizer=True)
+    quantizer.set_global(None)
+
+    pipeline = QuantizationPipeline[tuple[torch.Tensor]](
+        model,
+        inputs,
+        quantizer=quantizer,
+        qspecs={
+            "aten.linear.default": {None: 1},
+            "aten.relu.default": {None: 1},
+            "aten.sigmoid.default": {None: 1},
+            "aten.neg.default": {None: 1},
+        },
+    )
+
+    pipeline.run()
+
+    graph = pipeline.tester.get_graph(StageType.QUANTIZE)
+    unannotated_nodes = [
+        node.name
+        for node in graph.nodes
+        if node.op == "call_function" and Q_ANNOTATION_KEY not in node.meta
+    ]
+    assert not unannotated_nodes
+
+
 mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights)
 mv3.eval()
 normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
diff --git a/backends/cortex_m/quantizer/pattern_matcher.py b/backends/cortex_m/quantizer/pattern_matcher.py
index 3694fd22a6c..6e09fdbe58f 100644
--- a/backends/cortex_m/quantizer/pattern_matcher.py
+++ b/backends/cortex_m/quantizer/pattern_matcher.py
@@ -113,15 +113,29 @@ def _get_match(self, node_queue: List[Node]) -> List[Node]:
         return []
 
     def _get_matches(
-        self, node_queue: List[Node], quantization_config: QuantizationConfig
+        self, node_queue: List[Node], quantization_config: Optional[QuantizationConfig]
     ) -> List[PatternMatchResult]:
         """Returns the longest accepted match starting at the first node of the
         queue as well as longer rejected matches.
         """
+        # Annotating with None means rejecting quantization - this is always supported.
+        if quantization_config is None:
+            # Empty queue: nothing to index; the loop below is skipped for it too.
+            if not node_queue:
+                return []
+
+            node = node_queue[0]
+            if node.meta.get(self.Q_PATTERN_MATCHED_KEY, False):
+                return [
+                    PatternMatchResult([node], False, self.REJECT_PREVIOUSLY_ANNOTATED)
+                ]
+
+            node.meta[self.Q_PATTERN_MATCHED_KEY] = True
+            return [PatternMatchResult([node], True)]
+
         matches: list[PatternMatchResult] = []
         accepted = False
         max_match_length = len(node_queue)
-
         while max_match_length > 0 and not accepted:
             match = self._get_match(node_queue[:max_match_length])
             max_match_length = (
@@ -136,7 +150,7 @@ def _get_matches(
         return matches
 
     def _dequeue_and_get_matches(
-        self, node_queue: List[Node], quantization_config: QuantizationConfig
+        self, node_queue: List[Node], quantization_config: Optional[QuantizationConfig]
     ) -> List[PatternMatchResult]:
         """Dequeues the longest accepted match starting at the first node of the
         queue, and returns all potential matches that were checked (rejected
@@ -160,7 +174,7 @@ def _dequeue_and_get_matches(
         return potential_matches
 
     def find_pattern_matches(
-        self, nodes: Iterator[Node], quantization_config: QuantizationConfig
+        self, nodes: Iterator[Node], quantization_config: Optional[QuantizationConfig]
     ) -> Iterator[PatternMatchResult]:
         """Match all given patterns in the graph and return match results with
         acceptance/rejection status. Each node can only be part of one match,