Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 95 additions & 2 deletions backends/arm/test/quantizer/test_selective_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict

import torch

from executorch.backends.arm.quantizer import (
get_symmetric_a16w8_quantization_config,
get_symmetric_quantization_config,
Expand All @@ -16,13 +17,17 @@
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.test.harness.stages import StageType
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
from torchvision import models, transforms # type: ignore[import-untyped]
from torchvision.ops.misc import Conv2dNormActivation # type: ignore[import-untyped]


def get_quantizer():
def get_quantizer(use_composable_quantizer: bool = False):
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
quantizer = TOSAQuantizer(tosa_spec)
quantizer = TOSAQuantizer(
tosa_spec, use_composable_quantizer=use_composable_quantizer
)
quantizer.set_global(get_symmetric_quantization_config())
return quantizer

Expand Down Expand Up @@ -53,6 +58,25 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y


class Cat(torch.nn.Module):

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.cat((x, y), dim=1)


class LinearGraphTail(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear(x)
x = torch.relu(x)
x = torch.sigmoid(x)
return torch.neg(x)


class AddSoftmaxAdd(torch.nn.Module):
module_names = {"add_0": None, "add_1": None}
module_types = {
Expand Down Expand Up @@ -131,6 +155,75 @@ def test_selective_quant_module_type_tosa_INT(model):
pipeline.run()


def test_selective_quant_cat_node_target_none_tosa_INT():
model = Cat()
inputs = (torch.randn(1, 2, 4), torch.randn(1, 3, 4))

quantizer = get_quantizer(use_composable_quantizer=True)
quantizer.set_node_target(torch.ops.aten.cat.default, None)

pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
model,
inputs,
quantizer=quantizer,
qspecs={
"aten.cat.default": {
None: 1,
},
},
Comment on lines +169 to +173
)

pipeline.run()


def test_composable_io_none_skips_global_tosa_INT():
model = Add()
inputs = (torch.randn(1, 10), torch.randn(1, 10))

quantizer = get_quantizer(use_composable_quantizer=True)
quantizer.set_io(None)

pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
model,
inputs,
quantizer=quantizer,
input_qspecs={None: 2},
output_qspecs={None: 1},
)

pipeline.run()


def test_composable_global_none_linear_graph_tail_tosa_INT():
model = LinearGraphTail()
inputs = (torch.randn(1, 10),)

quantizer = get_quantizer(use_composable_quantizer=True)
quantizer.set_global(None)

pipeline = QuantizationPipeline[tuple[torch.Tensor]](
model,
inputs,
quantizer=quantizer,
qspecs={
"aten.linear.default": {None: 1},
"aten.relu.default": {None: 1},
"aten.sigmoid.default": {None: 1},
"aten.neg.default": {None: 1},
},
)

pipeline.run()

graph = pipeline.tester.get_graph(StageType.QUANTIZE)
unannotated_nodes = [
node.name
for node in graph.nodes
if node.op == "call_function" and Q_ANNOTATION_KEY not in node.meta
]
assert not unannotated_nodes


mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights)
mv3.eval()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Expand Down
18 changes: 14 additions & 4 deletions backends/cortex_m/quantizer/pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,25 @@ def _get_match(self, node_queue: List[Node]) -> List[Node]:
return []

def _get_matches(
self, node_queue: List[Node], quantization_config: QuantizationConfig
self, node_queue: List[Node], quantization_config: Optional[QuantizationConfig]
) -> List[PatternMatchResult]:
"""Returns the longest accepted match starting at the first node of the
queue as well as longer rejected matches.
"""
# Annotating with None means rejecting quantization - this is always supported.
if quantization_config is None:
node = node_queue[0]
if node.meta.get(self.Q_PATTERN_MATCHED_KEY, False):
return [
PatternMatchResult([node], False, self.REJECT_PREVIOUSLY_ANNOTATED)
]

node.meta[self.Q_PATTERN_MATCHED_KEY] = True
return [PatternMatchResult([node], True)]
Comment on lines +124 to +130

matches: list[PatternMatchResult] = []
accepted = False
max_match_length = len(node_queue)

while max_match_length > 0 and not accepted:
match = self._get_match(node_queue[:max_match_length])
max_match_length = (
Expand All @@ -136,7 +146,7 @@ def _get_matches(
return matches

def _dequeue_and_get_matches(
self, node_queue: List[Node], quantization_config: QuantizationConfig
self, node_queue: List[Node], quantization_config: Optional[QuantizationConfig]
) -> List[PatternMatchResult]:
"""Dequeues the longest accepted match starting at the first node of the
queue, and returns all potential matches that were checked (rejected
Expand All @@ -160,7 +170,7 @@ def _dequeue_and_get_matches(
return potential_matches

def find_pattern_matches(
self, nodes: Iterator[Node], quantization_config: QuantizationConfig
self, nodes: Iterator[Node], quantization_config: Optional[QuantizationConfig]
) -> Iterator[PatternMatchResult]:
"""Match all given patterns in the graph and return match results with
acceptance/rejection status. Each node can only be part of one match,
Expand Down
Loading