From a6d99fc6ea0f2a3a3a814a05f6ea87b1c9e4c7b0 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Tue, 5 May 2026 08:27:27 +0200 Subject: [PATCH] Arm backend: Don't execute eagerly with sym-ints Passes that eagerly execute ops with symbolic shapes to trace new metadata will break the graph. https://github.com/pytorch/pytorch/issues/182940 tries to figure out why eager execution breaks the graph. Signed-off-by: Oscar Andersson Change-Id: I384eedde392ee76a015f30a5f164c5bdf7f94b7e --- backends/arm/_passes/arm_pass_utils.py | 13 +++++++ .../normalize_delegate_io_layout_pass.py | 9 ++--- backends/arm/_passes/rewrite_conv_pass.py | 35 ++++++++----------- backends/arm/_passes/rewrite_upsample.py | 8 ++--- 4 files changed, 37 insertions(+), 28 deletions(-) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 9176f761220..000f92135eb 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -413,3 +413,16 @@ def to_2tuple(value): if len(value) == 1: return (value[0], value[0]) return tuple(value) + + +def permute_fake_tensor_metadata( + fake_tensor: FakeTensor, permute_dims: tuple[int, ...] 
+) -> FakeTensor: + permuted_shape = tuple(fake_tensor.shape[dim] for dim in permute_dims) + meta_tensor = torch.empty( + permuted_shape, + dtype=fake_tensor.dtype, + device="meta", + requires_grad=fake_tensor.requires_grad, + ) + return FakeTensor(fake_tensor.fake_mode, meta_tensor, fake_tensor.fake_device) diff --git a/backends/arm/_passes/normalize_delegate_io_layout_pass.py b/backends/arm/_passes/normalize_delegate_io_layout_pass.py index d1b1d964b87..c55eec5c851 100644 --- a/backends/arm/_passes/normalize_delegate_io_layout_pass.py +++ b/backends/arm/_passes/normalize_delegate_io_layout_pass.py @@ -11,6 +11,7 @@ create_node, get_first_fake_tensor, is_param_node, + permute_fake_tensor_metadata, ) from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -63,8 +64,8 @@ def _normalize_input_layout(self, graph_module: torch.fx.GraphModule) -> bool: args=(node, list(transpose_perm)), from_node=node, ) - permute_node.meta["val"] = exir_ops.edge.aten.permute_copy.default( - node.meta["val"], list(transpose_perm) + permute_node.meta["val"] = permute_fake_tensor_metadata( + get_first_fake_tensor(node), transpose_perm ) users = [user for user in node.users if user != permute_node] @@ -91,8 +92,8 @@ def _rewrite_output_arg( args=(arg, list(dim_order)), from_node=arg, ) - permute_node.meta["val"] = exir_ops.edge.aten.permute_copy.default( - output_fake, list(dim_order) + permute_node.meta["val"] = permute_fake_tensor_metadata( + output_fake, dim_order ) return permute_node, True diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py index ed4df2e43b6..a51f1ae0555 100644 --- a/backends/arm/_passes/rewrite_conv_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -16,6 +16,7 @@ get_first_fake_tensor, get_param_tensor, is_persistent_buffer, + permute_fake_tensor_metadata, ) from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, @@ 
-421,9 +422,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 args=(x, list(pre_permute_dims)), from_node=node, ) - x.meta["val"] = exir_ops.edge.aten.permute_copy.default( - input_fake_tensor, list(pre_permute_dims) + input_tensor_for_tosa_fake = permute_fake_tensor_metadata( + input_fake_tensor, pre_permute_dims ) + x.meta["val"] = input_tensor_for_tosa_fake weight = self._rewrite_weight( graph_module, weight, @@ -431,7 +433,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 permute_dims=OHWI_ORDER, name_suffix="ohwi", ) - input_tensor_for_tosa_fake = input_fake_tensor.permute(pre_permute_dims) weight_fake_tensor = get_first_fake_tensor(weight) conv_args = ( x, @@ -471,9 +472,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 args=(x, list(pre_permute_dims)), from_node=node, ) - x.meta["val"] = exir_ops.edge.aten.permute_copy.default( - input_fake_tensor, list(pre_permute_dims) + input_tensor_for_tosa_fake = permute_fake_tensor_metadata( + input_fake_tensor, pre_permute_dims ) + x.meta["val"] = input_tensor_for_tosa_fake weight = self._rewrite_weight( graph_module, weight, @@ -481,9 +483,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 permute_dims=ODHWI_ORDER, name_suffix="odhwi", ) - input_tensor_for_tosa_fake = input_fake_tensor.permute( - pre_permute_dims - ) weight_fake_tensor = get_first_fake_tensor(weight) elif self._is_depthwise_conv2d(node): target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default @@ -496,9 +495,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 args=(x, list(pre_permute_dims)), from_node=node, ) - x.meta["val"] = exir_ops.edge.aten.permute_copy.default( - input_fake_tensor, list(pre_permute_dims) + input_tensor_for_tosa_fake = permute_fake_tensor_metadata( + input_fake_tensor, pre_permute_dims ) + x.meta["val"] = input_tensor_for_tosa_fake kh, kw = weight_shape[2], 
weight_shape[3] in_channels = input_fake_tensor.shape[1] m_length = weight_shape[0] // in_channels @@ -510,9 +510,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 name_suffix="hwicm", reshape_dims=(kh, kw, in_channels, m_length), ) - input_tensor_for_tosa_fake = input_fake_tensor.permute( - pre_permute_dims - ) weight_fake_tensor = get_first_fake_tensor(weight) else: target_op = exir_ops.backend.tosa.CONV2D.default @@ -525,9 +522,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 args=(x, list(pre_permute_dims)), from_node=node, ) - x.meta["val"] = exir_ops.edge.aten.permute_copy.default( - input_fake_tensor, list(pre_permute_dims) + input_tensor_for_tosa_fake = permute_fake_tensor_metadata( + input_fake_tensor, pre_permute_dims ) + x.meta["val"] = input_tensor_for_tosa_fake weight = self._rewrite_weight( graph_module, weight, @@ -535,9 +533,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 permute_dims=NHWC_ORDER, name_suffix="ohwi", ) - input_tensor_for_tosa_fake = input_fake_tensor.permute( - pre_permute_dims - ) weight_fake_tensor = get_first_fake_tensor(weight) conv_args = ( @@ -612,8 +607,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 TosaSpecialDtype.meta_key() ): node_replacement.meta[TosaSpecialDtype.meta_key()] = special_dtype - node_replacement.meta["val"] = exir_ops.edge.aten.permute_copy.default( - node_replacement_fake_tensor, list(post_permute_dims) + node_replacement.meta["val"] = permute_fake_tensor_metadata( + node_replacement_fake_tensor, post_permute_dims ) node.replace_all_uses_with(node_replacement) diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py index 9f81f5cbbe5..68a088286fa 100644 --- a/backends/arm/_passes/rewrite_upsample.py +++ b/backends/arm/_passes/rewrite_upsample.py @@ -13,6 +13,7 @@ create_node, create_shape_node, get_first_fake_tensor, + 
permute_fake_tensor_metadata, ) from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.exir.dialects._ops import ops as exir_ops @@ -196,7 +197,7 @@ def call(self, graph_module): args=(x, list(self._NHWC_ORDER)), from_node=node, ) - pre_permute.meta["val"] = exir_ops.edge.aten.permute_copy.default( + pre_permute.meta["val"] = permute_fake_tensor_metadata( get_first_fake_tensor(x), list(self._NHWC_ORDER) ) @@ -255,9 +256,8 @@ def call(self, graph_module): args=(node_replacement, list(self._NHWC_INVERSE_ORDER)), from_node=node, ) - post_permute.meta["val"] = exir_ops.edge.aten.permute_copy.default( - node_replacement_fake, - list(self._NHWC_INVERSE_ORDER), + post_permute.meta["val"] = permute_fake_tensor_metadata( + node_replacement_fake, self._NHWC_INVERSE_ORDER ) node.replace_all_uses_with(post_permute) graph_module.graph.erase_node(node)