diff --git a/backends/arm/_passes/decompose_int_pow_pass.py b/backends/arm/_passes/decompose_int_pow_pass.py index 2df8d3b2522..bb29d34d6bf 100644 --- a/backends/arm/_passes/decompose_int_pow_pass.py +++ b/backends/arm/_passes/decompose_int_pow_pass.py @@ -32,12 +32,16 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] exp = args[1] - # Handle zero first and return early if exp == 0: - # return a tensor of ones with the same shape as x - return super().call_operator( + zeros = super().call_operator( + exir_ops.edge.aten.sub.Tensor, (x, x), {}, meta, True + ) + ones = super().call_operator( exir_ops.edge.aten.full_like.default, (x, 1), {}, meta, True ) + return super().call_operator( + exir_ops.edge.aten.add.Tensor, (zeros, ones), {}, meta, True + ) if not isinstance(exp, int): return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/test/ops/test_pow.py b/backends/arm/test/ops/test_pow.py index 2d007fa7e68..6d304ce0627 100644 --- a/backends/arm/test/ops/test_pow.py +++ b/backends/arm/test/ops/test_pow.py @@ -147,7 +147,6 @@ def test_pow_tensor_tensor_vgf_no_quant(test_data: Pow_TensorTensor.input_t): x_fail_FP = { "exp_two": "TOSA constraints: If x <0 .", - "exp_zero": "MLETORCH-2041 : Invalid inputs.", } diff --git a/backends/arm/test/passes/test_decompose_int_pow_pass.py b/backends/arm/test/passes/test_decompose_int_pow_pass.py index 6846392f248..7761c031e2c 100644 --- a/backends/arm/test/passes/test_decompose_int_pow_pass.py +++ b/backends/arm/test/passes/test_decompose_int_pow_pass.py @@ -59,18 +59,18 @@ def get_inputs(self) -> input_t: def test_decompose_int_pow_tosa_FP(data: TestParam) -> None: module_with_inputs, nbr_muls = data module = cast(torch.nn.Module, module_with_inputs) + pow_op = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar" pipeline = PassPipeline[input_t]( module, module_with_inputs.get_inputs(), quantize=False, ops_before_pass={ - "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 1, + pow_op: 1, }, ops_not_before_pass=[], ops_after_pass={ "executorch_exir_dialects_edge__ops_aten_mul_Tensor": nbr_muls, }, - ops_not_after_pass=["executorch_exir_dialects_edge__ops_pow_Tensor_Scalar"], pass_list=[DecomposeIntPowPass], ) pipeline.run()