From fae7030e218c5667fb4b1e4bf9a3fe333af42831 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Fri, 9 Jan 2026 17:47:26 +0100 Subject: [PATCH 1/5] Cortex-M backend: Add support for planned (depthwise) conv scratch size Add scratch tensors to the operator signatures, which are then assigned exir.memory.alloc nodes. These allocs are automatically memory planned by ExecuTorch. Introduce `required_cmsis_nn_buffer_sizes`, which computes the buffer sizes from node properties + the Cortex-M configuration. The function uses the buffer-size functions registered per target in backends/cortex_m/passes/scratch_buffer_sizes.py. This is used to set the size of the allocs in ConvertToCortexMPass. Finally, modify the kernels to use the new scratch tensor instead of allocating temporary memory. Add a new macro CORTEX_M_ENABLE_ASSERTS to do a safety check that the AOT-computed buffer size is equal to the buffer size computed at runtime. Use this when testing. Signed-off-by: Erik Lundell Change-Id: Ia7ec8eda87833888a0639b480e531fd17818298a --- backends/arm/scripts/build_executorch.sh | 8 + backends/cortex_m/CMakeLists.txt | 8 + backends/cortex_m/ops/op_quantized_conv2d.cpp | 34 ++-- .../ops/op_quantized_depthwise_conv2d.cpp | 31 +-- backends/cortex_m/ops/operators.py | 12 +- backends/cortex_m/ops/operators.yaml | 5 +- backends/cortex_m/passes/__init__.py | 1 + .../passes/convert_to_cortex_m_pass.py | 42 ++++- .../cortex_m/passes/scratch_buffer_sizes.py | 176 ++++++++++++++++++ backends/cortex_m/test/build_test_runner.sh | 2 +- 10 files changed, 279 insertions(+), 40 deletions(-) create mode 100644 backends/cortex_m/passes/scratch_buffer_sizes.py diff --git a/backends/arm/scripts/build_executorch.sh b/backends/arm/scripts/build_executorch.sh index cf7e327b9ce..3d672a041f3 100755 --- a/backends/arm/scripts/build_executorch.sh +++ b/backends/arm/scripts/build_executorch.sh @@ -7,6 +7,7 @@ # Optional parameter: # --build_type= "Release" | "Debug" | "RelWithDebInfo" | "UndefinedSanitizer" | 
"AddressSanitizer" # --etdump build with devtools-etdump support +# --cmake-args= Additional arguments passed to cmake configure set -eu @@ -24,6 +25,7 @@ build_type="Release" build_devtools=OFF build_with_etdump=OFF is_linux_musl=0 +extra_cmake_args=() help() { echo "Usage: $(basename $0) [options]" @@ -32,6 +34,7 @@ help() { echo " --build_type= Build with Release, Debug, RelWithDebInfo, UndefinedSanitizer or AddressSanitizer, default is ${build_type}" echo " --devtools Build Devtools libs" echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" + echo " --cmake-args= Additional arguments passed to cmake configure" echo " --toolchain= Toolchain can be specified (arm-none-eabi-gcc, arm-zephyr-eabi-gcc, aarch64-linux-musl-gcc). Default: ${toolchain}" exit 0 } @@ -43,6 +46,10 @@ for arg in "$@"; do --build_type=*) build_type="${arg#*=}";; --devtools) build_devtools=ON ;; --etdump) build_with_etdump=ON ;; + --cmake-args=*) + # shellcheck disable=SC2206 + extra_cmake_args=(${arg#*=}) + ;; --toolchain=*) toolchain="${arg#*=}";; *) ;; @@ -85,6 +92,7 @@ cmake_args=( -DCMAKE_BUILD_TYPE=${build_type} -DEXECUTORCH_BUILD_DEVTOOLS=${build_devtools} -DEXECUTORCH_BUILD_ARM_ETDUMP=${build_with_etdump} + "${extra_cmake_args[@]}" ) if [[ ${is_linux_musl} -eq 1 ]]; then diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index 8c8255b7b1b..75dc0b49e27 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -30,6 +30,10 @@ set(CMSIS_NN_LOCAL_PATH "" CACHE PATH "Path to existing local CMSIS-NN installation" ) +option(CORTEX_M_ENABLE_ASSERTS + "Enable additional Cortex-M runtime assertions and validation checks" + OFF +) # Try to find existing / local CMSIS-NN installation. This is useful for # debugging and testing with local changes. 
This is not common, as the CMSIS-NN @@ -87,6 +91,10 @@ target_link_libraries( PRIVATE executorch PRIVATE kernels_util_all_deps ) +target_compile_definitions( + cortex_m_kernels + PRIVATE $<$:CORTEX_M_ENABLE_ASSERTS> +) # Include directories for cortex_m_kernels target_include_directories( diff --git a/backends/cortex_m/ops/op_quantized_conv2d.cpp b/backends/cortex_m/ops/op_quantized_conv2d.cpp index 7d4433690f6..b94c49b1d74 100644 --- a/backends/cortex_m/ops/op_quantized_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_conv2d.cpp @@ -112,6 +112,7 @@ Tensor& quantized_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, Tensor& out) { if (!validate_conv2d_arguments( context, @@ -182,31 +183,30 @@ Tensor& quantized_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } - const int32_t buffer_bytes = arm_convolve_wrapper_s8_get_buffer_size( +#ifdef CORTEX_M_ENABLE_ASSERTS + const int32_t runtime_buffer_bytes = arm_convolve_wrapper_s8_get_buffer_size( &conv_params, &input_dims, &filter_dims, &output_dims); - if (buffer_bytes < 0) { + if (runtime_buffer_bytes < 0) { ET_LOG( Error, "quantized_conv2d_out: CMSIS-NN buffer size calculation failed"); context.fail(Error::Internal); return out; } - if (buffer_bytes > 0) { - auto buffer_or_error = - context.allocate_temp(buffer_bytes, kCortexMMveAlignment); - if (!buffer_or_error.ok()) { - ET_LOG( - Error, - "quantized_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - static_cast(buffer_bytes), - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); - return out; - } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; + if (scratch.nbytes() != static_cast(runtime_buffer_bytes)) { + ET_LOG( + Error, + 
"quantized_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + static_cast(runtime_buffer_bytes)); + context.fail(Error::Internal); + return out; } +#endif const arm_cmsis_nn_status status = arm_convolve_wrapper_s8( &cmsis_context, diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index 8dec61e0af1..7ed761bad9d 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -150,6 +150,7 @@ Tensor& quantized_depthwise_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, Tensor& out) { if (!validate_depthwise_conv2d_arguments( context, @@ -220,32 +221,32 @@ Tensor& quantized_depthwise_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } - const int32_t buffer_bytes = arm_depthwise_conv_wrapper_s8_get_buffer_size( - &dw_conv_params, &input_dims, &filter_dims, &output_dims); - if (buffer_bytes < 0) { +#ifdef CORTEX_M_ENABLE_ASSERTS + const int32_t runtime_buffer_bytes = + arm_depthwise_conv_wrapper_s8_get_buffer_size( + &dw_conv_params, &input_dims, &filter_dims, &output_dims); + if (runtime_buffer_bytes < 0) { ET_LOG( Error, "quantized_depthwise_conv2d_out: CMSIS-NN buffer size calculation failed"); context.fail(Error::Internal); return out; } - - auto buffer_or_error = context.allocate_temp( - static_cast(buffer_bytes), kCortexMMveAlignment); - if (!buffer_or_error.ok()) { + if (scratch.nbytes() != static_cast(runtime_buffer_bytes)) { ET_LOG( Error, - "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - static_cast(buffer_bytes), - static_cast(buffer_or_error.error())); - 
context.fail(buffer_or_error.error()); + "quantized_depthwise_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + static_cast(runtime_buffer_bytes)); + context.fail(Error::Internal); return out; } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; - +#endif const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8( &cmsis_context, &dw_conv_params, diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 2c35ed8730b..ad9e3b4113f 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -638,7 +638,8 @@ def pad_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch" ") -> Tensor" ) @@ -657,6 +658,7 @@ def pad_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " "*, Tensor(a!) out" ") -> Tensor(a!)" ) @@ -733,6 +735,7 @@ def quantized_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -762,6 +765,7 @@ def quantized_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: if input.dim() != 4 or weight.dim() != 4: raise RuntimeError("quantized_conv2d expects 4D input and weight tensors") @@ -830,7 +834,8 @@ def quantized_conv2d_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch" ") -> Tensor" ) @@ -850,6 +855,7 @@ def quantized_conv2d_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " "*, Tensor(a!) 
out" ") -> Tensor(a!)" ) @@ -870,6 +876,7 @@ def quantized_depthwise_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -900,6 +907,7 @@ def quantized_depthwise_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: if input.dim() != 4 or weight.dim() != 4: raise RuntimeError( diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index e0ebbfab868..c1fbbf01e9d 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -65,13 +65,14 @@ - arg_meta: null kernel_name: cortex_m::pad_out -- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: cortex_m::quantized_conv2d_out -- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) + +- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? 
bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index 92179ec6654..c379461949f 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -33,6 +33,7 @@ def _ensure_cortex_m_dependencies() -> None: _ensure_cortex_m_dependencies() +from .cortex_m_pass import CortexMPass # noqa # usort: skip from .activation_fusion_pass import ActivationFusionPass # noqa from .clamp_hardswish_pass import ClampHardswishPass # noqa from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 418f6cd63ff..cb514a8f0ca 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -6,25 +6,29 @@ # LICENSE file in the root directory of this source tree. 
import executorch.backends.cortex_m.ops.operators # noqa +import executorch.exir as exir import torch import torch.fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor + +from executorch.backends.cortex_m.passes import CortexMPass from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot +from executorch.backends.cortex_m.passes.scratch_buffer_sizes import ( + required_cmsis_nn_buffer_sizes, +) from executorch.backends.transforms.utils import ( create_constant_placeholder, get_param_tensor, is_param_node, ) - -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass from executorch.exir.dialects._ops import ops as exir_ops from torch.export.graph_signature import InputKind from torch.fx.passes.infra.pass_manager import PassResult -class ConvertToCortexMPass(XNNPACKPass): +class ConvertToCortexMPass(CortexMPass): """ Cortex-M backend pass for replacing supported quantized kernels with Cortex-M accelerated kernels. @@ -238,6 +242,15 @@ def _get_convolution_replacement(self, node): torch.tensor(quantized_shifts, dtype=torch.int32), ) + with node.graph.inserting_before(node): + # Args of alloc are overwritten with planned size at a later point. 
+ uninitialized_args = (((0,), torch.uint8),) + scratch = node.graph.call_function( + exir.memory.alloc, + args=uninitialized_args, + kwargs={}, + ) + if use_depthwise_conv: # Compute depth_multiplier for depthwise convolution # For depthwise: output_channels = input_channels * depth_multiplier @@ -263,6 +276,7 @@ def _get_convolution_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, ) return exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default, new_args else: @@ -280,9 +294,30 @@ def _get_convolution_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, ) return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args + def _set_scratch_buffer_size(self, node: torch.fx.Node) -> None: + scratch_buffer_sizes = required_cmsis_nn_buffer_sizes( + node, self.target_config.backend + ) + if scratch_buffer_sizes is None: + return + + for i, scratch_buffer_size in enumerate(reversed(scratch_buffer_sizes)): + scratch_arg = node.args[-(i + 1)] + if ( + not isinstance(scratch_arg, torch.fx.Node) + or scratch_arg.target != exir.memory.alloc + ): + raise RuntimeError( + f"Expected scratch alloc node as final argument(s) for {node.target}, got {scratch_arg}." + ) + + # buffer size is given in bytes, always use uint8 as dtype. 
+ scratch_arg.args = (((scratch_buffer_size,), torch.uint8),) + def _get_transpose_conv2d_replacement(self, node): """ Transform aten.convolution with transposed=True to cortex_m.quantized_transpose_conv2d @@ -459,6 +494,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: args=args, kwargs={}, ) + self._set_scratch_buffer_size(cortex_m_op) node.replace_all_uses_with(cortex_m_op) graph_module.graph.erase_node(node) diff --git a/backends/cortex_m/passes/scratch_buffer_sizes.py b/backends/cortex_m/passes/scratch_buffer_sizes.py new file mode 100644 index 00000000000..b16ad6ae661 --- /dev/null +++ b/backends/cortex_m/passes/scratch_buffer_sizes.py @@ -0,0 +1,176 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable +from typing import cast + +import cmsis_nn # type: ignore[import-not-found, import-untyped] +import executorch.backends.cortex_m.ops.operators # noqa + +import torch +import torch.fx + +from executorch.exir.dialects._ops import ops as exir_ops + +BufferSizeFunction = Callable[[cmsis_nn.Backend, torch.fx.Node], list[int]] + + +def _tensor_from_node(node: torch.fx.Node) -> torch.Tensor: + if "val" in node.meta: + return node.meta["val"] + elif node.op == "call_function": + args = ( + _tensor_from_node(arg) if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ) + return node.target(*args, **node.kwargs) # type: ignore[operator] + else: + raise RuntimeError("Encountered non-call_function without 'val' meta.") + + +def _shape_from_node(node: torch.fx.Node) -> torch.Size: + return _tensor_from_node(node).shape + + +def _get_common_conv_buffer_size_inputs( + *, + conv_node: torch.fx.Node, +) -> tuple[ + list[int], + list[int], + list[int], + list[int], + list[int], + list[int], +]: + x = cast(torch.fx.Node, conv_node.args[0]) + weight = cast(torch.fx.Node, 
conv_node.args[1]) + stride = cast(list[int], conv_node.args[3]) + padding = cast(list[int], conv_node.args[4]) + dilation = cast(list[int], conv_node.args[5]) + + # Input is NCHW (PyTorch); CMSIS-NN wants NHWC dims. + n, c_in, height, width = _shape_from_node(x) + + weight_shape = _shape_from_node(weight) + + # Output is NCHW; convert to NHWC dims. + out_n, out_c, out_h, out_w = _shape_from_node(conv_node) + + input_nhwc = [n, height, width, c_in] + output_nhwc = [out_n, out_h, out_w, out_c] + stride_hw = [int(stride[0]), int(stride[1])] + padding_hw = [int(padding[0]), int(padding[1])] + dilation_hw = [int(dilation[0]), int(dilation[1])] + + return ( + input_nhwc, + list(weight_shape), + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) + + +def cmsis_nn_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs(conv_node=conv_node) + input_offset = cast(int, conv_node.args[6]) + output_offset = cast(int, conv_node.args[7]) + output_qmin = cast(int, conv_node.args[10]) + output_qmax = cast(int, conv_node.args[11]) + + # Weight is in OHWI layout after conversion. 
+ c_out, kernel_h, kernel_w, c_in = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, c_in] + + return [ + int( + cmsis_nn.convolve_wrapper_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ) + ] + + +def cmsis_nn_depthwise_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs(conv_node=conv_node) + depth_multiplier = cast(int, conv_node.args[6]) + input_offset = cast(int, conv_node.args[7]) + output_offset = cast(int, conv_node.args[8]) + output_qmin = cast(int, conv_node.args[11]) + output_qmax = cast(int, conv_node.args[12]) + + # Weight is in IHWO layout after conversion. 
+ _, kernel_h, kernel_w, c_out = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, 1] + + return [ + int( + cmsis_nn.depthwise_conv_wrapper_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + ch_mult=depth_multiplier, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ) + ] + + +_target_to_buffer_sizes_registry = { + exir_ops.edge.cortex_m.quantized_conv2d.default: cmsis_nn_conv_buffer_size, + exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default: cmsis_nn_depthwise_conv_buffer_size, +} + + +def required_cmsis_nn_buffer_sizes( + node: torch.fx.Node, backend: cmsis_nn.Backend +) -> list[int] | None: + """Returns a sequence of scratch buffer sizes required by node, in bytes.""" + if node.target not in _target_to_buffer_sizes_registry: + return None + + buffer_size_function = _target_to_buffer_sizes_registry[node.target] + return buffer_size_function(backend, node) diff --git a/backends/cortex_m/test/build_test_runner.sh b/backends/cortex_m/test/build_test_runner.sh index 2505f83c9da..fff0fe79271 100755 --- a/backends/cortex_m/test/build_test_runner.sh +++ b/backends/cortex_m/test/build_test_runner.sh @@ -12,7 +12,7 @@ set -eu script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")") et_root_dir=$(realpath "${script_dir}/../../..") build_executorch="${et_root_dir}/backends/arm/scripts/build_executorch.sh" -${build_executorch} --devtools +${build_executorch} --devtools --cmake-args="-DCORTEX_M_ENABLE_ASSERTS=ON" # Build executor runner with selected aten ops and semi hosting build_dir="${et_root_dir}/arm_test" From 936ed4bbf6a781d0d7569d7df64c0df1b75c7e72 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Wed, 15 Apr 2026 15:40:05 +0200 Subject: [PATCH 2/5] Cortex-M backend: Add support for planned bmm scratch size. 
Follow the plan from previous buffer planning work. Signed-off-by: Erik Lundell Change-Id: I4bf3ca1cc421421b61903cba24856d0fd635d64a --- .../ops/op_quantized_batch_matmul.cpp | 35 ++++++++++--------- backends/cortex_m/ops/operators.py | 6 +++- backends/cortex_m/ops/operators.yaml | 2 +- .../passes/convert_to_cortex_m_pass.py | 15 ++++++-- .../cortex_m/passes/scratch_buffer_sizes.py | 22 +++++++++++- 5 files changed, 58 insertions(+), 22 deletions(-) diff --git a/backends/cortex_m/ops/op_quantized_batch_matmul.cpp b/backends/cortex_m/ops/op_quantized_batch_matmul.cpp index e6bc5a949ce..0aea882b107 100644 --- a/backends/cortex_m/ops/op_quantized_batch_matmul.cpp +++ b/backends/cortex_m/ops/op_quantized_batch_matmul.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
@@ -71,6 +72,7 @@ Tensor& quantized_batch_matmul_out( int64_t output_offset, int64_t output_multiplier, int64_t output_shift, + const Tensor& scratch, Tensor& out) { if (!validate_batch_matmul_arguments(context, lhs, rhs_transposed, out)) { return out; @@ -100,25 +102,26 @@ Tensor& quantized_batch_matmul_out( quant_params.multiplier = static_cast(output_multiplier); quant_params.shift = static_cast(output_shift); - const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&out_dims); - cmsis_nn_context ctx; ctx.buf = nullptr; - ctx.size = 0; - - if (buf_size > 0) { - auto buffer_or_error = context.allocate_temp(buf_size); - if (!buffer_or_error.ok()) { - ET_LOG( - Error, - "quantized_batch_matmul: failed to allocate scratch buffer (%d bytes)", - buf_size); - context.fail(buffer_or_error.error()); - return out; - } - ctx.buf = buffer_or_error.get(); - ctx.size = buf_size; + ctx.size = scratch.nbytes(); + if (ctx.size > 0) { + ctx.buf = scratch.mutable_data_ptr(); + } + +#ifdef CORTEX_M_ENABLE_ASSERTS + const int32_t runtime_buffer_bytes = + arm_fully_connected_s8_get_buffer_size(&out_dims); + if (ctx.size != static_cast(runtime_buffer_bytes)) { + ET_LOG( + Error, + "quantized_batch_matmul: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(ctx.size), + runtime_buffer_bytes); + context.fail(Error::Internal); + return out; } +#endif const arm_cmsis_nn_status status = arm_batch_matmul_s8( &ctx, diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index ad9e3b4113f..14bb6ae9fdd 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -271,13 +271,15 @@ def quantized_mul_impl( "quantized_batch_matmul(" "Tensor lhs, int lhs_zero_point, " "Tensor rhs_transposed, int rhs_zero_point, " - "int output_zero_point, int output_multiplier, int output_shift) -> Tensor" + "int output_zero_point, int output_multiplier, int output_shift, " + "Tensor scratch) -> Tensor" ) lib.define( 
"quantized_batch_matmul.out(" "Tensor lhs, int lhs_zero_point, " "Tensor rhs_transposed, int rhs_zero_point, " "int output_zero_point, int output_multiplier, int output_shift, " + "Tensor scratch, " "*, Tensor(a!) out) -> Tensor(a!)" ) @@ -291,6 +293,7 @@ def quantized_batch_matmul_meta( output_zero_point: int, output_multiplier: int, output_shift: int, + scratch: torch.Tensor, ) -> torch.Tensor: batch, lhs_rows, inner = lhs.shape batch_rhs, rhs_cols, inner_rhs = rhs_transposed.shape @@ -307,6 +310,7 @@ def quantized_batch_matmul_impl( output_zero_point: int, output_multiplier: int, output_shift: int, + scratch: torch.Tensor, ) -> torch.Tensor: # Offsets are negated zero points (CMSIS-NN convention) lhs_fp = lhs.to(torch.float32) + float(lhs_zero_point) diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index c1fbbf01e9d..d8395e50ece 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -95,7 +95,7 @@ - arg_meta: null kernel_name: cortex_m::quantized_max_pool2d_out -- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) 
variants: function kernels: - arg_meta: null diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index cb514a8f0ca..83ff2521e2f 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -27,6 +27,9 @@ from torch.export.graph_signature import InputKind from torch.fx.passes.infra.pass_manager import PassResult +UNINITIALIZED_ALLOC_ARGS = (((0,), torch.uint8),) +""" Args of alloc are overwritten with planned size at a later point.""" + class ConvertToCortexMPass(CortexMPass): """ @@ -243,11 +246,9 @@ def _get_convolution_replacement(self, node): ) with node.graph.inserting_before(node): - # Args of alloc are overwritten with planned size at a later point. - uninitialized_args = (((0,), torch.uint8),) scratch = node.graph.call_function( exir.memory.alloc, - args=uninitialized_args, + args=UNINITIALIZED_ALLOC_ARGS, kwargs={}, ) @@ -450,6 +451,13 @@ def _get_bmm_replacement(self, node): args=(rhs_node, [0, 2, 1]), ) + with node.graph.inserting_before(node): + scratch = node.graph.call_function( + exir.memory.alloc, + args=UNINITIALIZED_ALLOC_ARGS, + kwargs={}, + ) + args = ( lhs_node, -lhs_zp, @@ -458,6 +466,7 @@ def _get_bmm_replacement(self, node): output_zp, output_mult, output_shift, + scratch, ) return exir_ops.edge.cortex_m.quantized_batch_matmul.default, args diff --git a/backends/cortex_m/passes/scratch_buffer_sizes.py b/backends/cortex_m/passes/scratch_buffer_sizes.py index b16ad6ae661..28c7cafd041 100644 --- a/backends/cortex_m/passes/scratch_buffer_sizes.py +++ b/backends/cortex_m/passes/scratch_buffer_sizes.py @@ -35,7 +35,6 @@ def _shape_from_node(node: torch.fx.Node) -> torch.Size: def _get_common_conv_buffer_size_inputs( - *, conv_node: torch.fx.Node, ) -> tuple[ list[int], @@ -159,9 +158,30 @@ def cmsis_nn_depthwise_conv_buffer_size( ] +def cmsis_nn_batch_matmul_buffer_size( + backend: cmsis_nn.Backend, + matmul_node: 
torch.fx.Node, +) -> list[int]: + rhs_transposed = cast(torch.fx.Node, matmul_node.args[2]) + rhs_shape = _shape_from_node(rhs_transposed) + + _, rhs_cols, inner = rhs_shape + + return [ + int( + cmsis_nn.fully_connected_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + filter_nhwc=[inner, -1, -1, rhs_cols], # H and W values are unused. + ) + ) + ] + + _target_to_buffer_sizes_registry = { exir_ops.edge.cortex_m.quantized_conv2d.default: cmsis_nn_conv_buffer_size, exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default: cmsis_nn_depthwise_conv_buffer_size, + exir_ops.edge.cortex_m.quantized_batch_matmul.default: cmsis_nn_batch_matmul_buffer_size, } From 699e7312673949208244dd7b79094e5a2ea3c6cf Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Wed, 15 Apr 2026 16:47:26 +0200 Subject: [PATCH 3/5] Cortex-M backend: Add support for planned transposed conv scratch size. We can now reduce the memory size to 0 when building the cortex_m test runner. Signed-off-by: Erik Lundell Change-Id: Ieb1292c2db4651cd1f0756aa9d43ecedd5e262e5 --- .../ops/op_quantized_transpose_conv2d.cpp | 44 +++++----- backends/cortex_m/ops/operators.py | 10 ++- backends/cortex_m/ops/operators.yaml | 2 +- .../passes/convert_to_cortex_m_pass.py | 14 ++++ .../cortex_m/passes/scratch_buffer_sizes.py | 82 +++++++++++++++++-- backends/cortex_m/test/build_test_runner.sh | 2 +- 6 files changed, 124 insertions(+), 30 deletions(-) diff --git a/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp b/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp index e3f6135c7b9..ebb713531b8 100644 --- a/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
@@ -97,6 +98,8 @@ Tensor& quantized_transpose_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, + const Tensor& output_scratch, Tensor& out) { if (!validate_transpose_conv2d_arguments( context, @@ -179,44 +182,43 @@ Tensor& quantized_transpose_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } cmsis_nn_context output_context; output_context.buf = nullptr; - output_context.size = 0; - + output_context.size = output_scratch.nbytes(); + if (output_context.size > 0) { + output_context.buf = output_scratch.mutable_data_ptr(); + } +#ifdef CORTEX_M_ENABLE_ASSERTS const int32_t buffer_bytes = arm_transpose_conv_s8_get_buffer_size( &transpose_conv_params, &input_dims, &filter_dims, &output_dims); - auto buffer_or_error = context.allocate_temp( - static_cast(buffer_bytes), kCortexMMveAlignment); - if (!buffer_or_error.ok()) { + if (scratch.nbytes() != static_cast(buffer_bytes)) { ET_LOG( Error, - "quantized_transpose_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - buffer_bytes, - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); + "quantized_transpose_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + buffer_bytes); + context.fail(Error::Internal); return out; } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; const int32_t output_buffer_bytes = arm_transpose_conv_s8_get_reverse_conv_buffer_size( &transpose_conv_params, &input_dims, &filter_dims); - auto output_buffer_or_error = context.allocate_temp( - static_cast(output_buffer_bytes), kCortexMMveAlignment); - if (!output_buffer_or_error.ok()) { + if (output_scratch.nbytes() != static_cast(output_buffer_bytes)) { ET_LOG( Error, - 
"quantized_transpose_conv2d_out: failed to allocate output scratch buffer (%d bytes, error %d)", - output_buffer_bytes, - static_cast(output_buffer_or_error.error())); - context.fail(output_buffer_or_error.error()); + "quantized_transpose_conv2d_out: output scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(output_scratch.nbytes()), + output_buffer_bytes); + context.fail(Error::Internal); return out; } - output_context.buf = output_buffer_or_error.get(); - output_context.size = output_buffer_bytes; +#endif const arm_cmsis_nn_status status = arm_transpose_conv_wrapper_s8( &cmsis_context, diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 14bb6ae9fdd..d4393bc7ada 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -985,7 +985,9 @@ def quantized_depthwise_conv2d_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch, " + "Tensor output_scratch" ") -> Tensor" ) @@ -1004,6 +1006,8 @@ def quantized_depthwise_conv2d_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " + "Tensor output_scratch, " "*, Tensor(a!) out) -> Tensor(a!)" ) @@ -1069,6 +1073,8 @@ def quantized_transpose_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, + output_scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -1107,6 +1113,8 @@ def quantized_transpose_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, + output_scratch: torch.Tensor, ) -> torch.Tensor: """ Reference implementation of quantized transposed convolution. 
diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index d8395e50ece..8db109dea43 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -78,7 +78,7 @@ - arg_meta: null kernel_name: cortex_m::quantized_depthwise_conv2d_out -- func: cortex_m::quantized_transpose_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_transpose_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, Tensor output_scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 83ff2521e2f..ec092bee62d 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -399,6 +399,18 @@ def _get_transpose_conv2d_replacement(self, node): torch.tensor(quantized_shifts, dtype=torch.int32), ) + with node.graph.inserting_before(node): + scratch = node.graph.call_function( + exir.memory.alloc, + args=UNINITIALIZED_ALLOC_ARGS, + kwargs={}, + ) + output_scratch = node.graph.call_function( + exir.memory.alloc, + args=UNINITIALIZED_ALLOC_ARGS, + kwargs={}, + ) + new_args = ( x, weight_nhwc, @@ -413,6 +425,8 @@ def _get_transpose_conv2d_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, + output_scratch, ) return exir_ops.edge.cortex_m.quantized_transpose_conv2d.default, new_args diff --git 
a/backends/cortex_m/passes/scratch_buffer_sizes.py b/backends/cortex_m/passes/scratch_buffer_sizes.py index 28c7cafd041..36f3f8bbc17 100644 --- a/backends/cortex_m/passes/scratch_buffer_sizes.py +++ b/backends/cortex_m/passes/scratch_buffer_sizes.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. from collections.abc import Callable -from typing import cast +from typing import Any, cast import cmsis_nn # type: ignore[import-not-found, import-untyped] import executorch.backends.cortex_m.ops.operators # noqa @@ -36,6 +36,10 @@ def _shape_from_node(node: torch.fx.Node) -> torch.Size: def _get_common_conv_buffer_size_inputs( conv_node: torch.fx.Node, + *, + stride_arg_idx: int = 3, + padding_arg_idx: int = 4, + dilation_arg_idx: int = 5, ) -> tuple[ list[int], list[int], @@ -46,9 +50,9 @@ def _get_common_conv_buffer_size_inputs( ]: x = cast(torch.fx.Node, conv_node.args[0]) weight = cast(torch.fx.Node, conv_node.args[1]) - stride = cast(list[int], conv_node.args[3]) - padding = cast(list[int], conv_node.args[4]) - dilation = cast(list[int], conv_node.args[5]) + stride = cast(list[int], conv_node.args[stride_arg_idx]) + padding = cast(list[int], conv_node.args[padding_arg_idx]) + dilation = cast(list[int], conv_node.args[dilation_arg_idx]) # Input is NCHW (PyTorch); CMSIS-NN wants NHWC dims. 
n, c_in, height, width = _shape_from_node(x) @@ -178,17 +182,83 @@ def cmsis_nn_batch_matmul_buffer_size( ] -_target_to_buffer_sizes_registry = { +def cmsis_nn_transpose_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs( + conv_node=conv_node, + stride_arg_idx=3, + padding_arg_idx=4, + dilation_arg_idx=6, + ) + output_padding = cast(list[int], conv_node.args[5]) + input_offset = cast(int, conv_node.args[7]) + output_offset = cast(int, conv_node.args[8]) + output_qmin = cast(int, conv_node.args[11]) + output_qmax = cast(int, conv_node.args[12]) + c_out, kernel_h, kernel_w, kernel_c_in = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, kernel_c_in] + padding_offsets_hw = [int(output_padding[0]), int(output_padding[1])] + + return [ + int( + cmsis_nn.transpose_conv_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + padding_offsets_hw=padding_offsets_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ), + int( + cmsis_nn.transpose_conv_reverse_conv_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + padding_offsets_hw=padding_offsets_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ), + ] + + +_target_to_buffer_sizes_registry: dict[Any, BufferSizeFunction] = { exir_ops.edge.cortex_m.quantized_conv2d.default: cmsis_nn_conv_buffer_size, exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default: cmsis_nn_depthwise_conv_buffer_size, 
exir_ops.edge.cortex_m.quantized_batch_matmul.default: cmsis_nn_batch_matmul_buffer_size, + exir_ops.edge.cortex_m.quantized_transpose_conv2d.default: cmsis_nn_transpose_conv_buffer_size, } def required_cmsis_nn_buffer_sizes( node: torch.fx.Node, backend: cmsis_nn.Backend ) -> list[int] | None: - """Returns a sequence of scratch buffer sizes required by node, in bytes.""" + """Returns a sequence of scratch buffer sizes required by node, in bytes. + If no function is registered to compute this for the target of the node, return None. + """ if node.target not in _target_to_buffer_sizes_registry: return None diff --git a/backends/cortex_m/test/build_test_runner.sh b/backends/cortex_m/test/build_test_runner.sh index fff0fe79271..4080c481c1f 100755 --- a/backends/cortex_m/test/build_test_runner.sh +++ b/backends/cortex_m/test/build_test_runner.sh @@ -32,4 +32,4 @@ aten::unsqueeze_copy.out,\ aten::select_copy.int_out,\ aten::amax.out" -${build_executor_runner} --pte=semihosting --bundleio --target=ethos-u55-128 --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" --extra_build_flags="-DET_ATOL=5.0 -DET_RTOL=1.0" +${build_executor_runner} --pte=semihosting --bundleio --target=ethos-u55-128 --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" --extra_build_flags="-DET_ATOL=5.0 -DET_RTOL=1.0 -DET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=0" From 23c86c1e2c21b875d3314619071f2dc85ef30635 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 19 May 2026 08:40:17 +0200 Subject: [PATCH 4/5] Address review comments Mainly clarify the uninitialized/initialize alloc pattern.
Signed-off-by: Erik Lundell Change-Id: I062a5048094129be6ed8e9f7eafc096f34132b2f --- .../passes/convert_to_cortex_m_pass.py | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index ec092bee62d..59c6cc0dd45 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -24,12 +24,12 @@ is_param_node, ) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes import make_alloc_node +from torch._subclasses.fake_tensor import FakeTensorMode + from torch.export.graph_signature import InputKind from torch.fx.passes.infra.pass_manager import PassResult -UNINITIALIZED_ALLOC_ARGS = (((0,), torch.uint8),) -""" Args of alloc are overwritten with planned size at a later point.""" - class ConvertToCortexMPass(CortexMPass): """ @@ -40,6 +40,15 @@ class ConvertToCortexMPass(CortexMPass): by call_operator. 
""" + def _uninitialized_scratch(self): + """Create an unitialized alloc node to be initialize at a later point.""" + with FakeTensorMode() as mode: + return make_alloc_node( + self.exported_program.graph_module, + mode.from_tensor(torch.empty(0)), + None, + ) + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): """ Computes the precomputed kernel sum term (bias optional) @@ -246,11 +255,7 @@ def _get_convolution_replacement(self, node): ) with node.graph.inserting_before(node): - scratch = node.graph.call_function( - exir.memory.alloc, - args=UNINITIALIZED_ALLOC_ARGS, - kwargs={}, - ) + scratch = self._uninitialized_scratch() if use_depthwise_conv: # Compute depth_multiplier for depthwise convolution @@ -299,13 +304,19 @@ def _get_convolution_replacement(self, node): ) return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args - def _set_scratch_buffer_size(self, node: torch.fx.Node) -> None: + def _initialize_scratch_buffer_size(self, node: torch.fx.Node) -> None: + """For nodes with a registered buffer size function for node.target, set the buffer sizes + of the last n args, which should be exir.memory.alloc nodes. For nodes without a + registered function, do nothing. + """ + scratch_buffer_sizes = required_cmsis_nn_buffer_sizes( node, self.target_config.backend ) if scratch_buffer_sizes is None: return + # Assume that scratch_buffer_sizes are given from left to right in the call signature of node.target. 
for i, scratch_buffer_size in enumerate(reversed(scratch_buffer_sizes)): scratch_arg = node.args[-(i + 1)] if ( @@ -400,16 +411,8 @@ def _get_transpose_conv2d_replacement(self, node): ) with node.graph.inserting_before(node): - scratch = node.graph.call_function( - exir.memory.alloc, - args=UNINITIALIZED_ALLOC_ARGS, - kwargs={}, - ) - output_scratch = node.graph.call_function( - exir.memory.alloc, - args=UNINITIALIZED_ALLOC_ARGS, - kwargs={}, - ) + scratch = self._uninitialized_scratch() + output_scratch = self._uninitialized_scratch() new_args = ( x, @@ -466,11 +469,7 @@ def _get_bmm_replacement(self, node): ) with node.graph.inserting_before(node): - scratch = node.graph.call_function( - exir.memory.alloc, - args=UNINITIALIZED_ALLOC_ARGS, - kwargs={}, - ) + scratch = self._uninitialized_scratch() args = ( lhs_node, @@ -517,7 +516,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: args=args, kwargs={}, ) - self._set_scratch_buffer_size(cortex_m_op) + self._initialize_scratch_buffer_size(cortex_m_op) node.replace_all_uses_with(cortex_m_op) graph_module.graph.erase_node(node) From 264fc9e4cfc57e5299039dfe5dc9d47cfa0d7fbd Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 19 May 2026 09:06:16 +0200 Subject: [PATCH 5/5] Change helper function names. Signed-off-by: Erik Lundell Change-Id: I8da2906a5f4cc69d15d033d8e5d1113d8b4afc4e --- .../cortex_m/passes/convert_to_cortex_m_pass.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 59c6cc0dd45..e61ddaf63bc 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -40,7 +40,7 @@ class ConvertToCortexMPass(CortexMPass): by call_operator. 
""" - def _uninitialized_scratch(self): + def _create_uninitialized_alloc_node(self): """Create an unitialized alloc node to be initialize at a later point.""" with FakeTensorMode() as mode: return make_alloc_node( @@ -255,7 +255,7 @@ def _get_convolution_replacement(self, node): ) with node.graph.inserting_before(node): - scratch = self._uninitialized_scratch() + scratch = self._create_uninitialized_alloc_node() if use_depthwise_conv: # Compute depth_multiplier for depthwise convolution @@ -304,7 +304,7 @@ def _get_convolution_replacement(self, node): ) return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args - def _initialize_scratch_buffer_size(self, node: torch.fx.Node) -> None: + def _initialize_alloc_node_size(self, node: torch.fx.Node) -> None: """For nodes with a registered buffer size function for node.target, set the buffer sizes of the last n args, which should be exir.memory.alloc nodes. For nodes without a registered function, do nothing. @@ -411,8 +411,8 @@ def _get_transpose_conv2d_replacement(self, node): ) with node.graph.inserting_before(node): - scratch = self._uninitialized_scratch() - output_scratch = self._uninitialized_scratch() + scratch = self._create_uninitialized_alloc_node() + output_scratch = self._create_uninitialized_alloc_node() new_args = ( x, @@ -469,7 +469,7 @@ def _get_bmm_replacement(self, node): ) with node.graph.inserting_before(node): - scratch = self._uninitialized_scratch() + scratch = self._create_uninitialized_alloc_node() args = ( lhs_node, @@ -516,7 +516,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: args=args, kwargs={}, ) - self._initialize_scratch_buffer_size(cortex_m_op) + self._initialize_alloc_node_size(cortex_m_op) node.replace_all_uses_with(cortex_m_op) graph_module.graph.erase_node(node)