
Arm backend: validate TOSA RESIZE legality constraints consistently #19631

@Rob-Hughes-Arm

Follow-up to #19069 / #19151.

The Arm backend now appears to reject the originally reported bilinear align_corners=False exact 1/16 downscale case. That case lowered to:

scales = [2, 32, 2, 32]
offset = [15, 15]
border = [-15, -15]
mode = BILINEAR

and violated the TOSA rule:

scale_y_d < 16 * scale_y_n
scale_x_d < 16 * scale_x_n
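A short derivation reproduces the reported tuple. The sketch below is hypothetical and is not the Arm backend's code. It assumes half-pixel sampling for align_corners=False. The reduced ratio out/in is doubled so that the half-pixel offset (scale_d - scale_n) / 2 is an integer, and the border is then chosen so that TOSA's output-shape equation reproduces the requested size. Under that assumption, every exact 1/16 downscale (in_size == 16 * out_size) gives the same (2, 32, 15, -15) on each axis, and that tuple fails the rule above.

```python
from math import gcd


def half_pixel_resize_params(in_size: int, out_size: int) -> tuple[int, int, int, int]:
    """Hypothetical helper: (scale_n, scale_d, offset, border) for one axis.

    Assumes half-pixel centres (align_corners=False). The reduced ratio
    out_size / in_size is doubled so the half-pixel offset is an integer.
    """
    g = gcd(out_size, in_size)
    scale_n = 2 * (out_size // g)
    scale_d = 2 * (in_size // g)
    offset = (scale_d - scale_n) // 2  # both terms are even, so this is exact
    # Pick border so that TOSA's output-shape equation gives out_size:
    # (in_size - 1) * scale_n - offset + border == (out_size - 1) * scale_d
    border = (out_size - 1) * scale_d - (in_size - 1) * scale_n + offset
    return scale_n, scale_d, offset, border


for in_size, out_size in [(256, 16), (448, 28)]:
    n, d, offset, border = half_pixel_resize_params(in_size, out_size)
    print((n, d, offset, border), d < 16 * n)  # (2, 32, 15, -15) False
```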

However, the current support checks appear to validate only this specific bilinear case. TOSA RESIZE legality is more general than that. In particular, the lower-bound downscale rule applies to the RESIZE operation itself, not only to bilinear resize.

For example, an exact nearest-neighbour 1/16 downscale appears able to produce the same invalid scale tuple:

scales = [2, 32, 2, 32]
offset = [15, 15]
border = [-15, -15]
mode = NEAREST

This would violate the same TOSA rule, since scale_d >= 16 * scale_n:

32 >= 16 * 2

but the current nearest-neighbour support check appears to allow nearest resize unconditionally.

Why this matters

This is a backend correctness issue that arises before any downstream compiler or runtime is involved.

The Arm backend should not emit a TOSA RESIZE op that violates the TOSA specification. If a resize cannot be represented legally as TOSA RESIZE, the backend should reject it, undelegate it, or lower it through some valid alternative.

The previous fix addressed the originally reported bilinear case, but TOSA RESIZE has several common legality constraints that should be checked consistently for all modes and all paths that emit RESIZE.

Current upstream behavior that looks incomplete

Current backends/arm/operator_support/upsample_support.py contains a support gate for nearest resize that appears to return True unconditionally:

@register_tosa_support_check
class UpsampleNearest2dSupported(SupportedTOSAOperatorCheck):
    """Provide the explicit TOSA support gate for nearest upsample."""

    targets = [exir_ops.edge.aten.upsample_nearest2d.vec]

    def is_node_tosa_supported(
        self, _node: fx.Node, _tosa_spec: TosaSpecification
    ) -> bool:
        # type: ignore[override, misc]
        return True

The bilinear support checker includes a lower-bound downscale check:

if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n:
    self.reporter.report_reject(
        node,
        "Bilinear RESIZE downscale must be strictly greater than 1/16",
    )
    return False

Similarly, backends/arm/tosa/dialect/ops/resize.py appears to apply the lower-bound downscale check only when resize_mode == "bilinear":

if resize_mode == "bilinear":
    scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale
    if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n:
        raise TosaValueError(
            "Bilinear RESIZE downscale must be strictly greater than 1/16",
            op="RESIZE",
        )

But TOSA RESIZE supports both NEAREST and BILINEAR, and the relevant legality checks are part of the common RESIZE operation validation.

TOSA requirements that should be checked generically

The TOSA 1.0.1 RESIZE specification defines common, mode-independent validity checks, including at least:

LEVEL_CHECK(scale_y_n / scale_y_d <= MAX_SCALE)
LEVEL_CHECK(scale_x_n / scale_x_d <= MAX_SCALE)

ERROR_IF(max(OH, OW, IH, IW) >= 16384)
ERROR_IF(scale_y_n <= 0 || scale_y_d <= 0 || scale_x_n <= 0 || scale_x_d <= 0)
ERROR_IF(scale_y_n > (1 << 11) || scale_x_n > (1 << 11))
ERROR_IF(scale_y_d >= 16 * scale_y_n || scale_x_d >= 16 * scale_x_n)
ERROR_IF(offset_y < -scale_y_n || offset_y >= 16 * scale_y_n)
ERROR_IF(offset_x < -scale_x_n || offset_x >= 16 * scale_x_n)
ERROR_IF(border_y < -16 * scale_y_n || border_y >= scale_y_n)
ERROR_IF(border_x < -16 * scale_x_n || border_x >= scale_x_n)
ERROR_IF(OH != idiv_check((IH - 1) * scale_y_n - offset_y + border_y, scale_y_d) + 1)
ERROR_IF(OW != idiv_check((IW - 1) * scale_x_n - offset_x + border_x, scale_x_d) + 1)
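For reference, the list above maps almost line for line onto Python. This is a minimal sketch. The function name is illustrative and is not an ExecuTorch API. It returns every violated rule rather than stopping at the first. MAX_SCALE depends on the TOSA level, so it is a parameter. The usage line passes 256, which is assumed here to be the 8K-level value and should be confirmed against the spec's level table.

```python
def resize_violations(
    ih: int, iw: int, oh: int, ow: int,
    scale: tuple[int, int, int, int],
    offset: tuple[int, int],
    border: tuple[int, int],
    max_scale: int,
) -> list[str]:
    """Hypothetical helper: list the TOSA RESIZE rules these parameters break."""
    scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale
    offset_y, offset_x = offset
    border_y, border_x = border
    errors: list[str] = []

    if max(oh, ow, ih, iw) >= 16384:
        errors.append("max(OH, OW, IH, IW) >= 16384")
    if min(scale) <= 0:
        errors.append("all scale terms must be > 0")
        return errors  # the remaining checks divide by or multiply these terms
    # LEVEL_CHECK(scale_n / scale_d <= MAX_SCALE), evaluated exactly in integers.
    if scale_y_n > max_scale * scale_y_d or scale_x_n > max_scale * scale_x_d:
        errors.append("scale_n / scale_d > MAX_SCALE (level check)")
    if scale_y_n > (1 << 11) or scale_x_n > (1 << 11):
        errors.append("scale_n > (1 << 11)")
    if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n:
        errors.append("downscale must be strictly greater than 1/16")
    if offset_y < -scale_y_n or offset_y >= 16 * scale_y_n:
        errors.append("offset_y out of range")
    if offset_x < -scale_x_n or offset_x >= 16 * scale_x_n:
        errors.append("offset_x out of range")
    if border_y < -16 * scale_y_n or border_y >= scale_y_n:
        errors.append("border_y out of range")
    if border_x < -16 * scale_x_n or border_x >= scale_x_n:
        errors.append("border_x out of range")

    # idiv_check: the division must be exact, and the result must reproduce OH / OW.
    num_y = (ih - 1) * scale_y_n - offset_y + border_y
    num_x = (iw - 1) * scale_x_n - offset_x + border_x
    if num_y % scale_y_d != 0 or oh != num_y // scale_y_d + 1:
        errors.append("OH inconsistent with IH, scale, offset, border")
    if num_x % scale_x_d != 0 or ow != num_x // scale_x_d + 1:
        errors.append("OW inconsistent with IW, scale, offset, border")
    return errors


print(resize_violations(256, 448, 16, 28, (2, 32, 2, 32), (15, 15), (-15, -15), max_scale=256))
# ['downscale must be strictly greater than 1/16']
```

With a 256x448 input and the tuple reported above, only the 1/16 rule fails. The offset, border, and output-shape checks all pass, so the tuple is otherwise self-consistent.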

The current support and fake-op validation paths appear to check only a subset of these. In particular, they appear to miss or apply inconsistently:

  • nearest-neighbour lower-bound downscale validation
  • MAX_SCALE / level-bound validation
  • offset range validation
  • output-shape consistency validation
  • possibly scale_n > (1 << 11), where not already covered elsewhere
  • possibly max(OH, OW, IH, IW) >= 16384

Minimal repro candidate: nearest exact 1/16 downscale

import shutil
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.backends.arm.quantizer import (
    VgfQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.capture._config import EdgeCompileConfig


ARTIFACT_DIR = Path("artifacts/tiny_nearest_resize_invalid_tosa")


class TinyNearestResizeProbe(nn.Module):
    def forward(self, x):
        return F.interpolate(
            x,
            scale_factor=1.0 / 16.0,
            mode="nearest",
        )


def strip_unused_guard_nodes(graph_module):
    for node in list(graph_module.graph.nodes):
        if node.op == "call_module" and node.target == "_guards_fn" and len(node.users) == 0:
            graph_module.graph.erase_node(node)
    graph_module.graph.lint()
    graph_module.recompile()


shutil.rmtree(ARTIFACT_DIR, ignore_errors=True)
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)

# 256x448 with scale_factor=1/16 -> 16x28: an exact 1/16 downscale on both axes
x = torch.randn(1, 3, 256, 448)
model = TinyNearestResizeProbe().eval()

exported_program = torch.export.export(model, (x,), strict=True)
graph_module = exported_program.module()
strip_unused_guard_nodes(graph_module)

quantizer = VgfQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT+int16"))
qconfig = get_symmetric_quantization_config(
    is_per_channel=True,
    is_qat=False,
    is_dynamic=False,
    act_qmin=-127,
    act_qmax=127,
    weight_qmin=-127,
    weight_qmax=127,
)
quantizer.set_global(qconfig).set_io(qconfig)

quantized_graph = quantizer.quantize_with_submodules(
    graph_module,
    calibration_samples=[(x,)],
    is_qat=False,
)
quantized_exported = torch.export.export(quantized_graph, (x,))

compile_spec = TosaCompileSpec(
    TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
).dump_intermediate_artifacts_to(str(ARTIFACT_DIR))

partitioner = TOSAPartitioner(compile_spec)
to_edge_transform_and_lower(
    quantized_exported,
    partitioner=[partitioner],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)

If this follows the same Arm resize-parameter generation as the bilinear case, the emitted resize parameters would be expected to include:

mode = NEAREST
scales = [2, 32, 2, 32]
offset = [15, 15]
border = [-15, -15]

That would be invalid TOSA because:

scale_d >= 16 * scale_n
32 >= 16 * 2
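The failing condition depends only on the ratio, because scale_d < 16 * scale_n is equivalent to scale_n / scale_d > 1/16. An exact 1/16 downscale is therefore illegal however the nearest path parameterises it. That covers, for example, (1, 16) without a half-pixel offset as well as the (2, 32) above. The conclusion does not depend on nearest reusing the bilinear parameter generation. Here is a minimal check, using a hypothetical helper name:

```python
def downscale_ok(scale_n: int, scale_d: int) -> bool:
    """TOSA RESIZE lower-bound rule for one axis: scale_d < 16 * scale_n."""
    return scale_d < 16 * scale_n


# Every representation of exactly 1/16 fails, including the reported (2, 32).
assert not any(downscale_ok(k, 16 * k) for k in range(1, 1025))
# Ratios just above 1/16 pass, e.g. 2/31.
assert downscale_ok(2, 31)
```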

Root cause

The Arm backend appears to validate only the specific bilinear 1/16 downscale case that was reported in #19069.

But the TOSA RESIZE legality checks are broader:

  • the lower-bound downscale rule applies to the RESIZE operation generally
  • MAX_SCALE applies to the scale ratio
  • offset and border ranges are constrained
  • the computed output shape must match the requested output shape
  • dimensions and scale terms have additional bounds

So the backend should ideally have one shared TOSA RESIZE legality check that is used by both:

  1. operator support checks, before delegation;
  2. fake TOSA dialect op validation / lowering-time validation.

Expected behavior

ExecuTorch Arm should not emit an invalid TOSA RESIZE.

For any RESIZE mode, if the computed TOSA parameters violate TOSA validity rules, the backend should do one of:

  • reject / undelegate the resize with a clear diagnostic;
  • legalize it into a valid sequence;
  • or lower it through an equivalent valid implementation.

At minimum, it should not emit a .tosa flatbuffer that a TOSA-compliant validator rejects.
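One way to illustrate the "equivalent valid implementation" option is sketched below. This is not something the backend is known to do. PyTorch's default mode="nearest" reads source index floor(dst * in_size / out_size). An exact integer-factor nearest downscale is therefore plain strided subsampling (x[..., ::16, ::16] for 1/16). It can also be split into two 1/4 RESIZE stages, and each stage satisfies scale_d < 16 * scale_n. The pure-Python index arithmetic below checks both claims for 256 and 448 input sizes. The helper names are hypothetical. The argument does not carry over to bilinear, because its interpolation weights do not compose this way.

```python
def nearest_src_index(dst: int, in_size: int, out_size: int) -> int:
    """Source index for PyTorch's default mode="nearest": floor(dst * in / out), clamped."""
    return min(dst * in_size // out_size, in_size - 1)


def one_stage(in_size: int, factor: int) -> list[int]:
    """Source indices read by a single 1/factor nearest downscale."""
    out_size = in_size // factor
    return [nearest_src_index(i, in_size, out_size) for i in range(out_size)]


def two_stage(in_size: int, f1: int, f2: int) -> list[int]:
    """Source indices read by a 1/f1 downscale followed by a 1/f2 downscale."""
    mid_size = in_size // f1
    out_size = mid_size // f2
    mid = [nearest_src_index(j, in_size, mid_size) for j in range(mid_size)]
    return [mid[nearest_src_index(i, mid_size, out_size)] for i in range(out_size)]


for size in (256, 448):
    direct = one_stage(size, 16)
    assert direct == list(range(0, size, 16))  # same as slicing with step 16
    assert direct == two_stage(size, 4, 4)     # 1/4 then 1/4, each a legal RESIZE ratio
```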

Actual behavior

The current code appears to reject bilinear RESIZE downscales at or below the 1/16 boundary, while nearest resize appears to be treated as always supported. The fake TOSA RESIZE validator also appears to apply the 1/16 check only to bilinear.

As a result, an exact nearest-neighbour 1/16 downscale appears able to lower to invalid TOSA.

Other TOSA RESIZE validity rules, such as MAX_SCALE, also appear not to be checked consistently in these support/fake-op validation paths.

Suggested fix

Add a shared helper that validates TOSA RESIZE parameters independently of the source PyTorch op and independently of resize mode, e.g.:

def validate_tosa_resize_parameters(
    *,
    input_hw,
    output_hw,
    scale,
    offset,
    border,
    resize_mode,
    tosa_spec,
):
    ...

This helper should enforce the relevant TOSA RESIZE constraints, including:

scale_y_n / scale_y_d <= MAX_SCALE
scale_x_n / scale_x_d <= MAX_SCALE
scale_y_n > 0
scale_y_d > 0
scale_x_n > 0
scale_x_d > 0
scale_y_n <= 1 << 11
scale_x_n <= 1 << 11
scale_y_d < 16 * scale_y_n
scale_x_d < 16 * scale_x_n
offset and border ranges
output shape consistency
dimension bounds

Then call that helper from:

  • UpsampleNearest2dSupported
  • UpsampleBilinear2dSupported
  • fake TOSA RESIZE validation
  • any other Arm path that can create a TOSA RESIZE
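The sketch below shows the suggested structure. It mirrors the proposed signature, with one assumption: a plain max_scale integer replaces tosa_spec so the code runs standalone. A hypothetical ResizeLegalityError stands in for the backend's TosaValueError. The helper raises on the first violated rule. The support-check side turns that into a rejection through a report_reject-style callback, and the fake-op side can call the helper directly and let it raise. None of these names are existing ExecuTorch APIs.

```python
from typing import Callable


class ResizeLegalityError(ValueError):
    """Hypothetical stand-in for the backend's TosaValueError."""


def validate_tosa_resize_parameters(
    *,
    input_hw: tuple[int, int],
    output_hw: tuple[int, int],
    scale: tuple[int, int, int, int],
    offset: tuple[int, int],
    border: tuple[int, int],
    resize_mode: str,
    max_scale: int,  # assumption: replaces tosa_spec so the sketch runs standalone
) -> None:
    """Raise ResizeLegalityError on the first violated TOSA RESIZE rule.

    resize_mode only labels the diagnostic: no rule depends on it.
    """
    if max(*input_hw, *output_hw) >= 16384:
        raise ResizeLegalityError(f"{resize_mode} RESIZE: dimension >= 16384")
    # scale is (y_n, y_d, x_n, x_d); pair each axis with its own terms.
    axes = zip("yx", input_hw, output_hw, scale[0::2], scale[1::2], offset, border)
    for axis, in_size, out_size, n, d, off, bord in axes:
        if n <= 0 or d <= 0:
            raise ResizeLegalityError(f"{resize_mode} RESIZE {axis}: scale terms must be > 0")
        num = (in_size - 1) * n - off + bord
        rules = [
            (n > max_scale * d, "scale_n / scale_d > MAX_SCALE"),
            (n > (1 << 11), "scale_n > (1 << 11)"),
            (d >= 16 * n, "downscale must be strictly greater than 1/16"),
            (off < -n or off >= 16 * n, "offset out of range"),
            (bord < -16 * n or bord >= n, "border out of range"),
            (num % d != 0 or out_size != num // d + 1, "output size inconsistent"),
        ]
        for violated, reason in rules:
            if violated:
                raise ResizeLegalityError(f"{resize_mode} RESIZE {axis}: {reason}")


def is_resize_supported(report_reject: Callable[[str], None], **params) -> bool:
    """Support-check side: report the violation and return False instead of raising."""
    try:
        validate_tosa_resize_parameters(**params)
    except ResizeLegalityError as err:
        report_reject(str(err))
        return False
    return True


# Fake-op side: call validate_tosa_resize_parameters(**params) directly and let it raise.
params = dict(
    input_hw=(256, 448), output_hw=(16, 28), scale=(2, 32, 2, 32),
    offset=(15, 15), border=(-15, -15), resize_mode="nearest", max_scale=256,
)
rejections: list[str] = []
print(is_resize_supported(rejections.append, **params), rejections)
# False ['nearest RESIZE y: downscale must be strictly greater than 1/16']
```

The helper uses resize_mode only to label the diagnostic, which keeps the nearest and bilinear checks from drifting apart. The callback lets the support check reject a node cleanly, so the node is left undelegated instead of failing later.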

Versions

Observed / investigated locally with:

executorch==1.2.0.dev20260305+cpu
torch==2.10.0
torchao==0.15.0
Windows 11

Also inspected current upstream main code paths in:

backends/arm/operator_support/upsample_support.py
backends/arm/tosa/dialect/ops/resize.py

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani

Labels

module: arm (Issues related to arm backend)
partner: arm (For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm)