diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 145e953394cd..28b125eec0b0 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -244,15 +244,20 @@ def __init__(self, model, subgraph, exp_tab, ctx): "STABLEHLO_ADD": functools.partial(self._convert_stablehlo_binary, relax_op=_op.add), "STABLEHLO_AND": self._convert_stablehlo_and, "STABLEHLO_BROADCAST_IN_DIM": self._convert_stablehlo_broadcast_in_dim, + "STABLEHLO_CBRT": self._convert_stablehlo_cbrt, "STABLEHLO_CLAMP": self._convert_stablehlo_clamp, "STABLEHLO_COMPARE": self._convert_stablehlo_compare, + "STABLEHLO_COMPOSITE": self._convert_stablehlo_composite, "STABLEHLO_CONCATENATE": self._convert_stablehlo_concatenate, + "STABLEHLO_CONVOLUTION": self._convert_stablehlo_convolution, "STABLEHLO_CONVERT": self._convert_stablehlo_convert, "STABLEHLO_COSINE": functools.partial(self._convert_stablehlo_unary, relax_op=_op.cos), "STABLEHLO_DIVIDE": functools.partial( self._convert_stablehlo_binary, relax_op=_op.divide ), + "STABLEHLO_DOT_GENERAL": self._convert_stablehlo_dot_general, "STABLEHLO_DYNAMIC_SLICE": self._convert_stablehlo_dynamic_slice, + "STABLEHLO_DYNAMIC_UPDATE_SLICE": self._convert_stablehlo_dynamic_update_slice, "STABLEHLO_EXPONENTIAL": functools.partial( self._convert_stablehlo_unary, relax_op=_op.exp ), @@ -280,13 +285,18 @@ def __init__(self, model, subgraph, exp_tab, ctx): "STABLEHLO_POWER": functools.partial( self._convert_stablehlo_binary, relax_op=_op.power ), + "STABLEHLO_REDUCE": self._convert_stablehlo_reduce, + "STABLEHLO_REDUCE_WINDOW": self._convert_stablehlo_reduce_window, + "STABLEHLO_REMAINDER": self._convert_stablehlo_remainder, "STABLEHLO_RSQRT": functools.partial(self._convert_stablehlo_unary, relax_op=_op.rsqrt), + "STABLEHLO_SCATTER": self._convert_stablehlo_scatter, "STABLEHLO_SELECT": functools.partial( self._convert_stablehlo_ternary, relax_op=_op.where ), "STABLEHLO_SHIFT_LEFT": functools.partial( self._convert_stablehlo_binary, relax_op=_op.left_shift ), + "STABLEHLO_SORT": self._convert_stablehlo_sort, "STABLEHLO_SUBTRACT": functools.partial( self._convert_stablehlo_binary, relax_op=_op.subtract ), @@ -1483,6 +1493,413 @@ def _get_stablehlo_options(self, op, options_cls): result.Init(op_options.Bytes, op_options.Pos) return result + def _get_static_tensor_shape(self, tensor, op_name): + """Return a statically-known TFLite tensor shape as Python ints.""" + try: + return [int(dim) for dim in self.get_tensor_shape(tensor)] + except (TypeError, ValueError) as err: + raise tvm.error.OpNotImplemented( + f"{op_name} requires statically-known tensor shapes" + ) from err + + def _get_stablehlo_i64_vector(self, vector, default): + """Convert an optional StableHLO int64 vector field to a Python int list.""" + if vector is None or isinstance(vector, int): + return list(default) + return [int(v) for v in vector] + + def _ensure_stablehlo_float_dtype(self, expr, op_name): + """Return expr dtype if the StableHLO subset supports it.""" + dtype = expr.struct_info.dtype + if not dtype.startswith("float"): + raise tvm.error.OpNotImplemented(f"{op_name} with dtype {dtype} is not supported") + return dtype + + def _convert_stablehlo_cbrt(self, op): + """Convert STABLEHLO_CBRT to a sign-preserving Relax expression.""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + assert len(self.get_output_tensors(op)) == 1 + + data = self.get_tensor_expr(input_tensors[0]) + dtype = self._ensure_stablehlo_float_dtype(data, "STABLEHLO_CBRT") + zero = relax.const(0, dtype) + exponent = relax.const(1.0 / 3.0, dtype) + + is_negative = self.bb.normalize(relax.op.less(data, zero)) + negative_base = self.bb.normalize(relax.op.negative(data)) + negative_root = self.bb.normalize(relax.op.power(negative_base, exponent)) + negative_result = self.bb.normalize(relax.op.negative(negative_root)) + positive_result = self.bb.normalize(relax.op.power(data, exponent)) + return self.bb.normalize(relax.op.where(is_negative, negative_result, positive_result)) + + def _convert_stablehlo_remainder(self, op): + """Convert STABLEHLO_REMAINDER to truncating remainder for float tensors.""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + assert len(self.get_output_tensors(op)) == 1 + + lhs = self.get_tensor_expr(input_tensors[0]) + rhs = self.get_tensor_expr(input_tensors[1]) + self._ensure_stablehlo_float_dtype(lhs, "STABLEHLO_REMAINDER") + self._ensure_stablehlo_float_dtype(rhs, "STABLEHLO_REMAINDER") + + quotient = self.bb.normalize(relax.op.divide(lhs, rhs)) + truncated = self.bb.normalize(relax.op.trunc(quotient)) + product = self.bb.normalize(relax.op.multiply(rhs, truncated)) + return self.bb.normalize(relax.op.subtract(lhs, product)) + + def _get_stablehlo_simple_body_op(self, body_subgraph_index, parent_op_name, input_count): + """Return the single operator from a simple StableHLO body subgraph.""" + if body_subgraph_index <= 0 or body_subgraph_index >= self.model.SubgraphsLength(): + raise tvm.error.OpNotImplemented( + f"{parent_op_name} requires a valid non-main body subgraph" + ) + + body_subgraph = self.model.Subgraphs(body_subgraph_index) + if ( + body_subgraph.InputsLength() != input_count + or body_subgraph.OutputsLength() != 1 + or body_subgraph.OperatorsLength() != 1 + ): + raise tvm.error.OpNotImplemented( + f"{parent_op_name} only supports single-op body subgraphs" + ) + + return body_subgraph.Operators(0) + + def _check_stablehlo_reduce_init( + self, init_tensor, reducer_name, parent_op_name="STABLEHLO_REDUCE" + ): + """Validate that the StableHLO reduce init value matches the Relax identity.""" + if self.has_expr(init_tensor.tensor_idx): + raise tvm.error.OpNotImplemented( + f"{parent_op_name} with dynamic init values is not supported" + ) + + init_value = np.asarray(self.get_tensor_value(init_tensor)) + if init_value.shape not in [(), (1,)]: + raise tvm.error.OpNotImplemented(f"{parent_op_name} requires scalar init values") + + dtype = init_value.dtype + scalar = init_value.item() + if reducer_name == "STABLEHLO_ADD": + is_identity = bool(np.isclose(scalar, 0)) + elif reducer_name == "STABLEHLO_MULTIPLY": + is_identity = bool(np.isclose(scalar, 1)) + elif reducer_name == "STABLEHLO_MAXIMUM": + if np.issubdtype(dtype, np.floating): + is_identity = bool(np.isneginf(scalar)) + elif np.issubdtype(dtype, np.integer): + is_identity = scalar == np.iinfo(dtype).min + else: + is_identity = False + elif reducer_name == "STABLEHLO_MINIMUM": + if np.issubdtype(dtype, np.floating): + is_identity = bool(np.isposinf(scalar)) + elif np.issubdtype(dtype, np.integer): + is_identity = scalar == np.iinfo(dtype).max + else: + is_identity = False + else: + raise tvm.error.OpNotImplemented( + f"{parent_op_name} reducer {reducer_name} is not supported" + ) + + if not is_identity: + raise tvm.error.OpNotImplemented( + f"{parent_op_name} init value must match the reducer identity" + ) + + def _convert_stablehlo_reduce(self, op): + """Convert the single-input STABLEHLO_REDUCE subset to Relax reductions.""" + from tflite.StablehloReduceOptions import StablehloReduceOptions + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + assert len(self.get_output_tensors(op)) == 1 + + opts = self._get_stablehlo_options(op, StablehloReduceOptions) + dimensions = self._get_stablehlo_i64_vector(opts.DimensionsAsNumpy(), []) + body_op = self._get_stablehlo_simple_body_op( + int(opts.BodySubgraphIndex()), "STABLEHLO_REDUCE", 2 + ) + reducer_name = self.get_op_code_str(body_op) + + reducers = { + "STABLEHLO_ADD": relax.op.sum, + "STABLEHLO_MAXIMUM": relax.op.max, + "STABLEHLO_MINIMUM": relax.op.min, + "STABLEHLO_MULTIPLY": relax.op.prod, + } + if reducer_name not in reducers: + raise tvm.error.OpNotImplemented( + f"STABLEHLO_REDUCE reducer {reducer_name} is not supported" + ) + + self._check_stablehlo_reduce_init(input_tensors[1], reducer_name) + data = self.get_tensor_expr(input_tensors[0]) + return self.bb.normalize(reducers[reducer_name](data, axis=dimensions, keepdims=False)) + + def _convert_stablehlo_reduce_window(self, op): + """Convert the NHWC 2D max-pool STABLEHLO_REDUCE_WINDOW subset.""" + from tflite.StablehloReduceWindowOptions import StablehloReduceWindowOptions + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + assert len(self.get_output_tensors(op)) == 1 + + opts = self._get_stablehlo_options(op, StablehloReduceWindowOptions) + body_op = self._get_stablehlo_simple_body_op( + int(opts.BodySubgraphIndex()), "STABLEHLO_REDUCE_WINDOW", 2 + ) + reducer_name = self.get_op_code_str(body_op) + if reducer_name != "STABLEHLO_MAXIMUM": + raise tvm.error.OpNotImplemented( + "STABLEHLO_REDUCE_WINDOW only supports MAXIMUM reducer windows" + ) + self._check_stablehlo_reduce_init(input_tensors[1], reducer_name, "STABLEHLO_REDUCE_WINDOW") + + data_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_REDUCE_WINDOW") + if len(data_shape) != 4: + raise tvm.error.OpNotImplemented("STABLEHLO_REDUCE_WINDOW only supports 4D input") + + window_dimensions = self._get_stablehlo_i64_vector(opts.WindowDimensionsAsNumpy(), []) + window_strides = self._get_stablehlo_i64_vector( + opts.WindowStridesAsNumpy(), [1] * len(window_dimensions) + ) + base_dilations = self._get_stablehlo_i64_vector( + opts.BaseDilationsAsNumpy(), [1] * len(window_dimensions) + ) + window_dilations = self._get_stablehlo_i64_vector( + opts.WindowDilationsAsNumpy(), [1] * len(window_dimensions) + ) + padding = self._get_stablehlo_i64_vector( + opts.PaddingAsNumpy(), [0] * (2 * len(window_dimensions)) + ) + + if ( + len(window_dimensions) != 4 + or len(window_strides) != 4 + or len(base_dilations) != 4 + or len(window_dilations) != 4 + or len(padding) != 8 + ): + raise tvm.error.OpNotImplemented( + "STABLEHLO_REDUCE_WINDOW only supports rank-4 window attributes" + ) + if window_dimensions[0] != 1 or window_dimensions[3] != 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_REDUCE_WINDOW only supports pooling over spatial dimensions" + ) + if window_strides[0] != 1 or window_strides[3] != 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_REDUCE_WINDOW only supports unit batch/channel strides" + ) + if base_dilations != [1, 1, 1, 1]: + raise tvm.error.OpNotImplemented( + "STABLEHLO_REDUCE_WINDOW with base dilation is not supported" + ) + if padding[0] != 0 or padding[1] != 0 or padding[6] != 0 or padding[7] != 0: + raise tvm.error.OpNotImplemented( + "STABLEHLO_REDUCE_WINDOW only supports spatial padding" + ) + + data = self.get_tensor_expr(input_tensors[0]) + return self.bb.normalize( + relax.op.nn.max_pool2d( + data, + pool_size=[window_dimensions[1], window_dimensions[2]], + strides=[window_strides[1], window_strides[2]], + padding=[padding[2], padding[4], padding[3], padding[5]], + dilation=[window_dilations[1], window_dilations[2]], + layout="NHWC", + out_layout="NHWC", + ) + ) + + def _convert_stablehlo_scatter(self, op): + """Convert the canonical point-update STABLEHLO_SCATTER subset.""" + from tflite.StablehloScatterOptions import StablehloScatterOptions + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be 3" + assert len(self.get_output_tensors(op)) == 1 + + opts = self._get_stablehlo_options(op, StablehloScatterOptions) + operand_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_SCATTER") + indices_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_SCATTER") + updates_shape = self._get_static_tensor_shape(input_tensors[2], "STABLEHLO_SCATTER") + operand_rank = len(operand_shape) + indices_rank = len(indices_shape) + + update_window_dims = self._get_stablehlo_i64_vector(opts.UpdateWindowDimsAsNumpy(), []) + inserted_window_dims = self._get_stablehlo_i64_vector(opts.InsertedWindowDimsAsNumpy(), []) + scatter_dims_to_operand_dims = self._get_stablehlo_i64_vector( + opts.ScatterDimsToOperandDimsAsNumpy(), [] + ) + index_vector_dim = int(opts.IndexVectorDim()) + + if indices_rank == 0 or index_vector_dim != indices_rank - 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_SCATTER only supports trailing index-vector dimensions" + ) + if update_window_dims: + raise tvm.error.OpNotImplemented( + "STABLEHLO_SCATTER only supports point updates without update windows" + ) + if inserted_window_dims != list(range(operand_rank)): + raise tvm.error.OpNotImplemented( + "STABLEHLO_SCATTER only supports point updates for every operand dimension" + ) + if scatter_dims_to_operand_dims != list(range(operand_rank)): + raise tvm.error.OpNotImplemented( + "STABLEHLO_SCATTER only supports canonical scatter-to-operand dimensions" + ) + if indices_shape[-1] != operand_rank or updates_shape != indices_shape[:-1]: + raise tvm.error.OpNotImplemented( + "STABLEHLO_SCATTER requires point update shapes to match scatter indices" + ) + + body_op = self._get_stablehlo_simple_body_op( + int(opts.UpdateComputationSubgraphIndex()), "STABLEHLO_SCATTER", 2 + ) + reducer_name = self.get_op_code_str(body_op) + reductions = { + "STABLEHLO_ADD": "add", + "STABLEHLO_MAXIMUM": "max", + "STABLEHLO_MINIMUM": "min", + "STABLEHLO_MULTIPLY": "mul", + } + if reducer_name not in reductions: + raise tvm.error.OpNotImplemented( + f"STABLEHLO_SCATTER reducer {reducer_name} is not supported" + ) + + operand = self.get_tensor_expr(input_tensors[0]) + indices = self.get_tensor_expr(input_tensors[1]) + updates = self.get_tensor_expr(input_tensors[2]) + return self.bb.normalize( + relax.op.scatter_nd(operand, indices, updates, reductions[reducer_name]) + ) + + def _convert_stablehlo_composite(self, op): + """Convert STABLEHLO_COMPOSITE by inlining a simple decomposition subgraph.""" + from tflite.StableHLOCompositeOptions import StableHLOCompositeOptions + + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(output_tensors) != 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_COMPOSITE only supports single-output decompositions" + ) + + opts = self._get_stablehlo_options(op, StableHLOCompositeOptions) + composite_name = opts.Name() + composite_name = ( + composite_name.decode("utf-8") if composite_name is not None else "" + ) + if opts.CompositeAttributesLength() != 0: + raise tvm.error.OpNotImplemented( + f"STABLEHLO_COMPOSITE {composite_name} with composite attributes is not supported" + ) + + decomposition_subgraph_index = int(opts.DecompositionSubgraphIndex()) + if ( + decomposition_subgraph_index <= 0 + or decomposition_subgraph_index >= self.model.SubgraphsLength() + ): + raise tvm.error.OpNotImplemented( + f"STABLEHLO_COMPOSITE {composite_name} requires a valid decomposition subgraph" + ) + decomposition_subgraph = self.model.Subgraphs(decomposition_subgraph_index) + if decomposition_subgraph.InputsLength() != len(input_tensors): + raise tvm.error.OpNotImplemented( + f"STABLEHLO_COMPOSITE {composite_name} decomposition input count mismatch" + ) + if decomposition_subgraph.OutputsLength() != 1: + raise tvm.error.OpNotImplemented( + f"STABLEHLO_COMPOSITE {composite_name} only supports single-output decompositions" + ) + + decomposition_exp_tab = ExprTable() + decomposition_converter = OperatorConverter( + self.model, decomposition_subgraph, decomposition_exp_tab, self.bb + ) + for decomposition_input_idx, composite_input in zip( + decomposition_subgraph.InputsAsNumpy(), input_tensors + ): + decomposition_input_name = get_tensor_name( + decomposition_subgraph, int(decomposition_input_idx) + ) + decomposition_exp_tab.set_expr( + decomposition_input_name, + self.get_tensor_expr(composite_input), + force_override=True, + ) + + decomposition_converter.check_unsupported_ops() + decomposition_converter.convert_op_to_relax() + decomposition_output_idx = int(decomposition_subgraph.Outputs(0)) + decomposition_output_tensor = decomposition_converter.get_tensors( + [decomposition_output_idx] + )[0] + for const_expr, value in decomposition_exp_tab.params.values(): + param_name = f"_param_{self.exp_tab.const_ctr}" + self.exp_tab.const_ctr += 1 + self.exp_tab.params[param_name] = (const_expr, value) + return decomposition_converter.get_tensor_expr(decomposition_output_tensor) + + def _convert_stablehlo_sort(self, op): + """Convert the single-input STABLEHLO_SORT subset to Relax sort.""" + from tflite.StablehloCompareOptions import StablehloCompareOptions + from tflite.StablehloComparisonDirection import StablehloComparisonDirection + from tflite.StablehloComparisonType import StablehloComparisonType + from tflite.StablehloSortOptions import StablehloSortOptions + + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 1 or len(output_tensors) != 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_SORT only supports single-input single-output sort" + ) + + opts = self._get_stablehlo_options(op, StablehloSortOptions) + if opts.IsStable(): + raise tvm.error.OpNotImplemented("STABLEHLO_SORT stable sort is not supported") + + body_op = self._get_stablehlo_simple_body_op( + int(opts.ComparatorSubgraphIndex()), "STABLEHLO_SORT", 2 + ) + comparator_name = self.get_op_code_str(body_op) + if comparator_name != "STABLEHLO_COMPARE": + raise tvm.error.OpNotImplemented( + f"STABLEHLO_SORT comparator {comparator_name} is not supported" + ) + + compare_opts = self._get_stablehlo_options(body_op, StablehloCompareOptions) + if ( + compare_opts.CompareType() + == StablehloComparisonType.STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER + ): + raise tvm.error.OpNotImplemented( + "STABLEHLO_SORT with TOTALORDER comparator is not supported" + ) + + direction = compare_opts.ComparisonDirection() + _DIR = StablehloComparisonDirection + if direction == _DIR.STABLEHLO_COMPARISON_DIRECTION_LT: + descending = False + elif direction == _DIR.STABLEHLO_COMPARISON_DIRECTION_GT: + descending = True + else: + raise tvm.error.OpNotImplemented("STABLEHLO_SORT only supports LT or GT comparators") + + data = self.get_tensor_expr(input_tensors[0]) + return self.bb.normalize( + relax.op.sort(data, axis=int(opts.Dimension()), descending=descending) + ) + def _convert_stablehlo_convert(self, op): """Convert STABLEHLO_CONVERT to Relax (astype). @@ -1719,6 +2136,189 @@ def _const_1d(values, dtype="int64"): return self.bb.normalize(relax.op.dynamic_strided_slice(operand, begin, end, strides)) + def _convert_stablehlo_dynamic_update_slice(self, op): + """Convert STABLEHLO_DYNAMIC_UPDATE_SLICE to Relax for static starts.""" + input_tensors = self.get_input_tensors(op) + # operand + update + N start-index scalars + assert len(input_tensors) >= 3, "input tensors length should be >= 3" + assert len(self.get_output_tensors(op)) == 1 + + operand_tensor = input_tensors[0] + update_tensor = input_tensors[1] + start_tensors = input_tensors[2:] + + op_name = "STABLEHLO_DYNAMIC_UPDATE_SLICE" + operand_shape = self._get_static_tensor_shape(operand_tensor, op_name) + update_shape = self._get_static_tensor_shape(update_tensor, op_name) + rank = len(operand_shape) + if len(update_shape) != rank or len(start_tensors) != rank: + raise tvm.error.OpNotImplemented( + "STABLEHLO_DYNAMIC_UPDATE_SLICE requires operand, update, " + "and start-index ranks to match" + ) + + if any(self.has_expr(t.tensor_idx) for t in start_tensors): + raise tvm.error.OpNotImplemented( + "STABLEHLO_DYNAMIC_UPDATE_SLICE with dynamic start indices is not supported" + ) + + start_vals = [int(np.asarray(self.get_tensor_value(t)).item()) for t in start_tensors] + for start, size, dim in zip(start_vals, update_shape, operand_shape): + if start < 0 or start + size > dim: + raise tvm.error.OpNotImplemented( + "STABLEHLO_DYNAMIC_UPDATE_SLICE with out-of-bounds update " + "indices is not supported" + ) + + update_indices = np.indices(update_shape, dtype=np.int64) + for axis, start in enumerate(start_vals): + update_indices[axis] += start + update_indices = np.moveaxis(update_indices, 0, -1) + + operand = self.get_tensor_expr(operand_tensor) + update = self.get_tensor_expr(update_tensor) + indices = self.bb.normalize(relax.const(update_indices, dtype="int64")) + return self.bb.normalize(relax.op.scatter_nd(operand, indices, update, "update")) + + def _convert_stablehlo_dot_general(self, op): + """Convert the canonical 2D STABLEHLO_DOT_GENERAL subset to Relax matmul.""" + from tflite.StablehloDotGeneralOptions import StablehloDotGeneralOptions + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + assert len(self.get_output_tensors(op)) == 1 + + opts = self._get_stablehlo_options(op, StablehloDotGeneralOptions) + lhs_batch_dims = self._get_stablehlo_i64_vector(opts.LhsBatchingDimensionsAsNumpy(), []) + rhs_batch_dims = self._get_stablehlo_i64_vector(opts.RhsBatchingDimensionsAsNumpy(), []) + lhs_contract_dims = self._get_stablehlo_i64_vector( + opts.LhsContractingDimensionsAsNumpy(), [] + ) + rhs_contract_dims = self._get_stablehlo_i64_vector( + opts.RhsContractingDimensionsAsNumpy(), [] + ) + + lhs_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_DOT_GENERAL") + rhs_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_DOT_GENERAL") + if len(lhs_shape) != 2 or len(rhs_shape) != 2: + raise tvm.error.OpNotImplemented("STABLEHLO_DOT_GENERAL only supports 2D matmul") + if lhs_batch_dims or rhs_batch_dims: + raise tvm.error.OpNotImplemented( + "STABLEHLO_DOT_GENERAL with batching dimensions is not supported" + ) + if lhs_contract_dims != [1] or rhs_contract_dims != [0]: + raise tvm.error.OpNotImplemented( + "STABLEHLO_DOT_GENERAL only supports canonical contracting dimensions" + ) + + lhs = self.get_tensor_expr(input_tensors[0]) + rhs = self.get_tensor_expr(input_tensors[1]) + return self.bb.normalize(relax.op.matmul(lhs, rhs)) + + def _convert_stablehlo_convolution(self, op): + """Convert the canonical 2D NHWC/HWIO STABLEHLO_CONVOLUTION subset.""" + from tflite.StablehloConvolutionOptions import StablehloConvolutionOptions + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + assert len(self.get_output_tensors(op)) == 1 + + opts = self._get_stablehlo_options(op, StablehloConvolutionOptions) + input_spatial_dims = self._get_stablehlo_i64_vector( + opts.InputSpatialDimensionsAsNumpy(), [] + ) + kernel_spatial_dims = self._get_stablehlo_i64_vector( + opts.KernelSpatialDimensionsAsNumpy(), [] + ) + output_spatial_dims = self._get_stablehlo_i64_vector( + opts.OutputSpatialDimensionsAsNumpy(), [] + ) + if input_spatial_dims != [1, 2]: + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION only supports NHWC input layout" + ) + if kernel_spatial_dims != [0, 1]: + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION only supports HWIO kernel layout" + ) + if output_spatial_dims != [1, 2]: + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION only supports NHWC output layout" + ) + + if ( + int(opts.InputBatchDimension()) != 0 + or int(opts.InputFeatureDimension()) != 3 + or int(opts.KernelInputFeatureDimension()) != 2 + or int(opts.KernelOutputFeatureDimension()) != 3 + or int(opts.OutputBatchDimension()) != 0 + or int(opts.OutputFeatureDimension()) != 3 + ): + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION only supports canonical NHWC/HWIO dimension numbers" + ) + if int(opts.BatchGroupCount()) != 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION with batch_group_count > 1 is not supported" + ) + if int(opts.FeatureGroupCount()) != 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION with feature_group_count > 1 is not supported" + ) + + data_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_CONVOLUTION") + kernel_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_CONVOLUTION") + if len(data_shape) != 4 or len(kernel_shape) != 4: + raise tvm.error.OpNotImplemented("STABLEHLO_CONVOLUTION only supports 2D convolution") + if data_shape[3] != kernel_shape[2]: + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION input channels must match kernel input channels" + ) + + window_strides = self._get_stablehlo_i64_vector(opts.WindowStridesAsNumpy(), [1, 1]) + padding = self._get_stablehlo_i64_vector(opts.PaddingAsNumpy(), [0, 0, 0, 0]) + lhs_dilation = self._get_stablehlo_i64_vector(opts.LhsDilationAsNumpy(), [1, 1]) + rhs_dilation = self._get_stablehlo_i64_vector(opts.RhsDilationAsNumpy(), [1, 1]) + window_reversal = opts.WindowReversalAsNumpy() + window_reversal = ( + [False, False] if window_reversal is None else [bool(v) for v in window_reversal] + ) + + if len(window_strides) != 2 or len(rhs_dilation) != 2: + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION only supports two spatial dimensions" + ) + if lhs_dilation != [1, 1]: + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION with lhs dilation is not supported" + ) + if any(window_reversal): + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION with window reversal is not supported" + ) + if len(padding) != 4: + raise tvm.error.OpNotImplemented( + "STABLEHLO_CONVOLUTION only supports 2D low/high padding" + ) + + # StableHLO stores padding as [low_h, high_h, low_w, high_w]. + relax_padding = [padding[0], padding[2], padding[1], padding[3]] + data = self.get_tensor_expr(input_tensors[0]) + kernel = self.get_tensor_expr(input_tensors[1]) + self._ensure_stablehlo_float_dtype(data, "STABLEHLO_CONVOLUTION") + self._ensure_stablehlo_float_dtype(kernel, "STABLEHLO_CONVOLUTION") + return self.bb.normalize( + relax.op.nn.conv2d( + data, + kernel, + strides=window_strides, + padding=relax_padding, + dilation=rhs_dilation, + data_layout="NHWC", + kernel_layout="HWIO", + ) + ) + def _convert_stablehlo_gather(self, op): """Convert STABLEHLO_GATHER to Relax (take-equivalent subset only). @@ -5528,19 +6128,18 @@ def _input_type(model): assert subgraph_count > 0 shape_dict = {} dtype_dict = {} - for subgraph_index in range(subgraph_count): - subgraph = model.Subgraphs(subgraph_index) - inputs_count = subgraph.InputsLength() - # TFLite subgraphs can validly have zero inputs (e.g. constant-only RANGE models). - for input_index in range(inputs_count): - input_ = subgraph.Inputs(input_index) - assert subgraph.TensorsLength() > input_ - tensor = subgraph.Tensors(input_) - input_shape = tuple(tensor.ShapeAsNumpy()) - tensor_type = tensor.Type() - input_name = get_tensor_name(subgraph, input_) - shape_dict[input_name] = input_shape - dtype_dict[input_name] = _decode_type(tensor_type) + subgraph = model.Subgraphs(0) + inputs_count = subgraph.InputsLength() + # TFLite subgraphs can validly have zero inputs (e.g. constant-only RANGE models). + for input_index in range(inputs_count): + input_ = subgraph.Inputs(input_index) + assert subgraph.TensorsLength() > input_ + tensor = subgraph.Tensors(input_) + input_shape = tuple(tensor.ShapeAsNumpy()) + tensor_type = tensor.Type() + input_name = get_tensor_name(subgraph, input_) + shape_dict[input_name] = input_shape + dtype_dict[input_name] = _decode_type(tensor_type) return shape_dict, dtype_dict @@ -5652,8 +6251,10 @@ def func(self, data): if dtype_dict is not None: _dtype_dict.update(dtype_dict) - # keep the same as tflite - assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)" + # Only Subgraphs(0) is converted into Relax main. Additional subgraphs are + # region bodies referenced by specific TFLite ops and are consumed by those + # op converters as needed. + assert model.SubgraphsLength() >= 1, "TFLite model must contain at least one subgraph" subgraph = model.Subgraphs(0) # model inputs / outputs diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index bb2fb0bfa74a..031c1553d8bf 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3677,6 +3677,9 @@ def _get_tflite_schema_enum(enum_name): # ── StableHLO BuiltinOptions2 schema modules ──────────────────────────── _tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOptions") _tfl_stablehlo_bcast_opts = _get_tflite_schema_module("StablehloBroadcastInDimOptions") +_tfl_stablehlo_composite_opts = _get_tflite_schema_module("StableHLOCompositeOptions") +_tfl_stablehlo_conv_opts = _get_tflite_schema_module("StablehloConvolutionOptions") +_tfl_stablehlo_dot_opts = _get_tflite_schema_module("StablehloDotGeneralOptions") _tfl_stablehlo_iota_opts = _get_tflite_schema_module("StablehloIotaOptions") _tfl_stablehlo_compare_opts = _get_tflite_schema_module("StablehloCompareOptions") _tfl_stablehlo_comp_dir = _get_tflite_schema_module("StablehloComparisonDirection") @@ -3684,6 +3687,10 @@ def _get_tflite_schema_enum(enum_name): _tfl_stablehlo_pad_opts = _get_tflite_schema_module("StablehloPadOptions") _tfl_stablehlo_dyn_slice_opts = _get_tflite_schema_module("StablehloDynamicSliceOptions") _tfl_stablehlo_gather_opts = _get_tflite_schema_module("StablehloGatherOptions") +_tfl_stablehlo_reduce_opts = _get_tflite_schema_module("StablehloReduceOptions") +_tfl_stablehlo_reduce_window_opts = _get_tflite_schema_module("StablehloReduceWindowOptions") +_tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions") +_tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions") _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions") _tfl_int32_vector = _get_tflite_schema_module("Int32Vector") @@ -3721,6 +3728,20 @@ def _tflite_int32_vector(builder, start_vector_fn, values): return builder.EndVector() +def _tflite_int64_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependInt64(value) + return builder.EndVector() + + +def _tflite_bool_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependBool(value) + return builder.EndVector() + + def _tflite_offset_vector(builder, start_vector_fn, offsets): start_vector_fn(builder, len(offsets)) for offset in reversed(offsets): @@ -3834,12 +3855,15 @@ def _build_subgraph(builder, *, tensors, operators, inputs, outputs): return _tfl_subgraph.SubGraphEnd(builder) -def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers): +def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers, extra_subgraphs=None): + all_subgraphs = [subgraph] + (extra_subgraphs or []) buffers_vec = _tflite_offset_vector(builder, _tfl_model.ModelStartBuffersVector, buffers) opcodes_vec = _tflite_offset_vector( builder, _tfl_model.ModelStartOperatorCodesVector, operator_codes ) - subgraphs_vec = _tflite_offset_vector(builder, _tfl_model.ModelStartSubgraphsVector, [subgraph]) + subgraphs_vec = _tflite_offset_vector( + builder, _tfl_model.ModelStartSubgraphsVector, all_subgraphs + ) _tfl_model.ModelStart(builder) _tfl_model.ModelAddBuffers(builder, buffers_vec) @@ -3896,6 +3920,453 @@ def _build_stablehlo_model(*, builtin_name, input_count): ) +def _build_stablehlo_model_with_unused_subgraph(): + """Build a StableHLO model with an unused extra subgraph.""" + builder = flatbuffers.Builder(1024) + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_ADD") + + main_tensors = [_build_tensor(builder, buffer_idx, [2, 2]) for buffer_idx in range(3)] + main_op = _build_operator(builder, 0, [0, 1], [2]) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[main_op], + inputs=[0, 1], + outputs=[2], + ) + + # Give the unused subgraph a conflicting input tensor name and different + # shape. from_tflite should infer the main function input shape only from + # Subgraphs(0). + extra_tensors = [_build_tensor(builder, buffer_idx, [4, 4]) for buffer_idx in range(3, 6)] + extra_op = _build_operator(builder, 0, [0, 1], [2]) + extra_subgraph = _build_subgraph( + builder, + tensors=extra_tensors, + operators=[extra_op], + inputs=[0, 1], + outputs=[2], + ) + + operator_codes = [_build_operator_code(builder, builtin_op)] + buffers = [_build_buffer(builder) for _ in range(6)] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[extra_subgraph], + operator_codes=operator_codes, + buffers=buffers, + ) + + +def _build_stablehlo_reduce_model(reducer_name, init_value): + """Build a single-input STABLEHLO_REDUCE model with a binary reducer body.""" + builder = flatbuffers.Builder(1024) + + dimensions_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_reduce_opts.StablehloReduceOptionsStartDimensionsVector, + [1], + ) + _tfl_stablehlo_reduce_opts.StablehloReduceOptionsStart(builder) + _tfl_stablehlo_reduce_opts.StablehloReduceOptionsAddDimensions(builder, dimensions_vec) + _tfl_stablehlo_reduce_opts.StablehloReduceOptionsAddBodySubgraphIndex(builder, 1) + reduce_opts = _tfl_stablehlo_reduce_opts.StablehloReduceOptionsEnd(builder) + + reduce_builtin = _get_stablehlo_builtin_operator("STABLEHLO_REDUCE") + reducer_builtin = _get_stablehlo_builtin_operator(reducer_name) + reduce_code = _build_operator_code(builder, reduce_builtin) + reducer_code = _build_operator_code(builder, reducer_builtin) + + main_tensors = [ + _build_tensor(builder, 0, [2, 3]), + _build_tensor(builder, 1, []), + _build_tensor(builder, 2, [2]), + ] + reduce_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options2_type=_tfl_builtin_options2.StablehloReduceOptions, + builtin_options2=reduce_opts, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[reduce_op], + inputs=[0], + outputs=[2], + ) + + body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(3, 6)] + reducer_op = _build_operator(builder, 1, [0, 1], [2]) + body_subgraph = _build_subgraph( + builder, + tensors=body_tensors, + operators=[reducer_op], + inputs=[0, 1], + outputs=[2], + ) + + buffers = [ + _build_buffer(builder), + _build_buffer(builder, np.array(init_value, dtype=np.float32).tobytes()), + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[body_subgraph], + operator_codes=[reduce_code, reducer_code], + buffers=buffers, + ) + + +def _build_stablehlo_sort_model(comparison_direction, is_stable=False): + """Build a single-input STABLEHLO_SORT model with a compare body.""" + builder = flatbuffers.Builder(1024) + + _tfl_stablehlo_sort_opts.StablehloSortOptionsStart(builder) + _tfl_stablehlo_sort_opts.StablehloSortOptionsAddDimension(builder, 1) + _tfl_stablehlo_sort_opts.StablehloSortOptionsAddIsStable(builder, is_stable) + _tfl_stablehlo_sort_opts.StablehloSortOptionsAddComparatorSubgraphIndex(builder, 1) + sort_opts = _tfl_stablehlo_sort_opts.StablehloSortOptionsEnd(builder) + + _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder) + _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection( + builder, comparison_direction + ) + compare_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder) + + sort_builtin = _get_stablehlo_builtin_operator("STABLEHLO_SORT") + compare_builtin = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE") + sort_code = _build_operator_code(builder, sort_builtin) + compare_code = _build_operator_code(builder, compare_builtin) + + main_tensors = [ + _build_tensor(builder, 0, [2, 3]), + _build_tensor(builder, 1, [2, 3]), + ] + sort_op = _build_operator( + builder, + 0, + [0], + [1], + builtin_options2_type=_tfl_builtin_options2.StablehloSortOptions, + builtin_options2=sort_opts, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[sort_op], + inputs=[0], + outputs=[1], + ) + + body_tensors = [ + _build_tensor(builder, 2, []), + _build_tensor(builder, 3, []), + _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.BOOL), + ] + compare_op = _build_operator( + builder, + 1, + [0, 1], + [2], + builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions, + builtin_options2=compare_opts, + ) + body_subgraph = _build_subgraph( + builder, + tensors=body_tensors, + operators=[compare_op], + inputs=[0, 1], + outputs=[2], + ) + + buffers = [_build_buffer(builder) for _ in range(5)] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[body_subgraph], + operator_codes=[sort_code, compare_code], + buffers=buffers, + ) + + +def _build_stablehlo_reduce_window_model( + reducer_name="STABLEHLO_MAXIMUM", + init_value=-np.inf, + base_dilations=None, +): + """Build an NHWC 2D STABLEHLO_REDUCE_WINDOW model.""" + builder = flatbuffers.Builder(1024) + if base_dilations is None: + base_dilations = [1, 1, 1, 1] + + window_dimensions_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowDimensionsVector, + [1, 2, 2, 1], + ) + window_strides_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowStridesVector, + [1, 2, 2, 1], + ) + base_dilations_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartBaseDilationsVector, + base_dilations, + ) + window_dilations_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowDilationsVector, + [1, 1, 1, 1], + ) + padding_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartPaddingVector, + [0, 0, 0, 0, 0, 0, 0, 0], + ) + + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStart(builder) + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowDimensions( + builder, window_dimensions_vec + ) + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowStrides( + builder, window_strides_vec + ) + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddBaseDilations( + builder, base_dilations_vec + ) + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowDilations( + builder, window_dilations_vec + ) + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddPadding(builder, padding_vec) + _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddBodySubgraphIndex(builder, 1) + reduce_window_opts = _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsEnd(builder) + + reduce_window_builtin = _get_stablehlo_builtin_operator("STABLEHLO_REDUCE_WINDOW") + reducer_builtin = _get_stablehlo_builtin_operator(reducer_name) + reduce_window_code = _build_operator_code(builder, reduce_window_builtin) + reducer_code = _build_operator_code(builder, reducer_builtin) + + main_tensors = [ + _build_tensor(builder, 0, [1, 4, 4, 1]), + _build_tensor(builder, 1, []), + _build_tensor(builder, 2, [1, 2, 2, 1]), + ] + reduce_window_op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options2_type=_tfl_builtin_options2.StablehloReduceWindowOptions, + builtin_options2=reduce_window_opts, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[reduce_window_op], + inputs=[0], + outputs=[2], + ) + + body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(3, 6)] + reducer_op = _build_operator(builder, 1, [0, 1], [2]) + body_subgraph = _build_subgraph( + builder, + tensors=body_tensors, + operators=[reducer_op], + inputs=[0, 1], + outputs=[2], + ) + + buffers = [ + _build_buffer(builder), + _build_buffer(builder, np.array(init_value, dtype=np.float32).tobytes()), + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[body_subgraph], + operator_codes=[reduce_window_code, reducer_code], + buffers=buffers, + ) + + +def _build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_dims=None): + """Build a canonical point-update STABLEHLO_SCATTER model.""" + builder = flatbuffers.Builder(1024) + if update_window_dims is None: + update_window_dims = [] + + update_window_dims_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartUpdateWindowDimsVector, + update_window_dims, + ) + inserted_window_dims_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartInsertedWindowDimsVector, + [0], + ) + scatter_dims_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartScatterDimsToOperandDimsVector, + [0], + ) + + _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStart(builder) + _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddUpdateWindowDims( + builder, update_window_dims_vec + ) + _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddInsertedWindowDims( + builder, inserted_window_dims_vec + ) + _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddScatterDimsToOperandDims( + builder, scatter_dims_vec + ) + _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddIndexVectorDim(builder, 1) + _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddUpdateComputationSubgraphIndex(builder, 1) + scatter_opts = _tfl_stablehlo_scatter_opts.StablehloScatterOptionsEnd(builder) + + scatter_builtin = _get_stablehlo_builtin_operator("STABLEHLO_SCATTER") + reducer_builtin = _get_stablehlo_builtin_operator(reducer_name) + scatter_code = _build_operator_code(builder, scatter_builtin) + reducer_code = _build_operator_code(builder, reducer_builtin) + + main_tensors = [ + _build_tensor(builder, 0, [4]), + _build_tensor(builder, 1, [2, 1], tensor_type=_tfl_tensor_type.INT32), + _build_tensor(builder, 2, [2]), + _build_tensor(builder, 3, [4]), + ] + scatter_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + builtin_options2_type=_tfl_builtin_options2.StablehloScatterOptions, + builtin_options2=scatter_opts, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[scatter_op], + inputs=[0, 1, 2], + outputs=[3], + ) + + body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(4, 7)] + reducer_op = _build_operator(builder, 1, [0, 1], [2]) + body_subgraph = _build_subgraph( + builder, + tensors=body_tensors, + operators=[reducer_op], + inputs=[0, 1], + outputs=[2], + ) + + buffers = [_build_buffer(builder) for _ in range(7)] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[body_subgraph], + operator_codes=[scatter_code, reducer_code], + buffers=buffers, + ) + + +def _build_stablehlo_composite_model(with_attributes=False, use_main_input_after_composite=False): + """Build a STABLEHLO_COMPOSITE model that decomposes to STABLEHLO_NEGATE.""" + builder = flatbuffers.Builder(1024) + + name = builder.CreateString("test.negate") + attributes = None + if with_attributes: + _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsStartCompositeAttributesVector( + builder, 1 + ) + builder.PrependUint8(1) + attributes = builder.EndVector() + + _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsStart(builder) + _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddName(builder, name) + _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddVersion(builder, 1) + _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddDecompositionSubgraphIndex(builder, 1) + if attributes is not None: + _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddCompositeAttributes( + builder, attributes + ) + composite_opts = _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsEnd(builder) + + composite_builtin = _get_stablehlo_builtin_operator("STABLEHLO_COMPOSITE") + negate_builtin = _get_stablehlo_builtin_operator("STABLEHLO_NEGATE") + add_builtin = _get_stablehlo_builtin_operator("STABLEHLO_ADD") + composite_code = _build_operator_code(builder, composite_builtin) + negate_code = _build_operator_code(builder, negate_builtin) + add_code = _build_operator_code(builder, add_builtin) + + main_tensors = [ + _build_tensor(builder, 0, [2, 2]), + _build_tensor(builder, 1, [2, 2]), + _build_tensor(builder, 2, [2, 2]), + ] + composite_op = _build_operator( + builder, + 0, + [0], + [1], + builtin_options2_type=_tfl_builtin_options2.StableHLOCompositeOptions, + builtin_options2=composite_opts, + ) + main_ops = [composite_op] + main_outputs = [1] + if use_main_input_after_composite: + main_ops.append(_build_operator(builder, 2, [0, 1], [2])) + main_outputs = [2] + + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=main_ops, + inputs=[0], + outputs=main_outputs, + ) + + decomposition_tensors = [ + _build_tensor(builder, 2, [2, 2]), + _build_tensor(builder, 3, [2, 2]), + ] + negate_op = _build_operator(builder, 1, [0], [1]) + decomposition_subgraph = _build_subgraph( + builder, + tensors=decomposition_tensors, + operators=[negate_op], + inputs=[0], + outputs=[1], + ) + + buffers = [_build_buffer(builder) for _ in range(4)] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + extra_subgraphs=[decomposition_subgraph], + operator_codes=[composite_code, negate_code, add_code], + buffers=buffers, + ) + + def _build_stablehlo_typed_binary_model(*, builtin_name, tensor_type): """Build a minimal TFLite StableHLO binary model with the requested tensor type.""" builder = flatbuffers.Builder(1024) @@ -3972,19 +4443,302 @@ def test_stablehlo_binary(builtin_name, relax_op): @I.ir_module class Expected: @R.function - def main( - x: R.Tensor((2, 2), dtype="float32"), - y: R.Tensor((2, 2), dtype="float32"), - ) -> R.Tensor((2, 2), dtype="float32"): - R.func_attr({"num_input": 2}) + def main( + x: R.Tensor((2, 2), dtype="float32"), + y: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_model_with_unused_subgraph(): + """TFLite StableHLO import ignores unused non-main subgraphs.""" + mod = _load_model_from_buffer(_build_stablehlo_model_with_unused_subgraph()) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 2), dtype="float32"), + y: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.add(x, y) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.parametrize( + "reducer_name, init_value, relax_op", + [ + ("STABLEHLO_ADD", 0.0, R.sum), + ("STABLEHLO_MAXIMUM", -np.inf, R.max), + ("STABLEHLO_MINIMUM", np.inf, R.min), + ("STABLEHLO_MULTIPLY", 1.0, R.prod), + ], +) +def test_stablehlo_reduce(reducer_name, init_value, relax_op): + """TFLite StableHLO REDUCE with simple binary reducer body subgraphs.""" + mod = _load_model_from_buffer(_build_stablehlo_reduce_model(reducer_name, init_value)) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2,), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2,), dtype="float32") = relax_op(x, axis=[1], keepdims=False) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_reduce_unsupported_reducer(): + """TFLite StableHLO REDUCE rejects unsupported body reducer ops.""" + buf = _build_stablehlo_reduce_model("STABLEHLO_SUBTRACT", 0.0) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="reducer"): + from_tflite(tflite_model) + + +def test_stablehlo_reduce_non_identity_init_unsupported(): + """TFLite StableHLO REDUCE rejects init values that Relax reductions cannot express.""" + buf = _build_stablehlo_reduce_model("STABLEHLO_ADD", 1.0) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="init value"): + from_tflite(tflite_model) + + +@pytest.mark.parametrize( + "comparison_direction, descending", + [ + ( + _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT, + False, + ), + ( + _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GT, + True, + ), + ], +) +def test_stablehlo_sort(comparison_direction, descending): + """TFLite StableHLO SORT with LT/GT scalar compare body subgraphs.""" + mod = _load_model_from_buffer(_build_stablehlo_sort_model(comparison_direction)) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 3), dtype="float32") = R.sort(x, axis=1, descending=descending) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_sort_unsupported_comparator(): + """TFLite StableHLO SORT rejects non-ordering comparators.""" + _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection + buf = _build_stablehlo_sort_model(_DIR.STABLEHLO_COMPARISON_DIRECTION_EQ) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="LT or GT"): + from_tflite(tflite_model) + + +def test_stablehlo_sort_stable_unsupported(): + """TFLite StableHLO SORT rejects stable sort until Relax exposes that contract.""" + _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection + buf = _build_stablehlo_sort_model(_DIR.STABLEHLO_COMPARISON_DIRECTION_LT, is_stable=True) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="stable sort"): + from_tflite(tflite_model) + + +def test_stablehlo_reduce_window_max_pool2d(): + """TFLite StableHLO REDUCE_WINDOW max reducer lowers to NHWC max_pool2d.""" + mod = _load_model_from_buffer(_build_stablehlo_reduce_window_model()) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((1, 4, 4, 1), dtype="float32"), + ) -> R.Tensor((1, 2, 2, 1), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((1, 2, 2, 1), dtype="float32") = R.nn.max_pool2d( + x, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_reduce_window_unsupported_reducer(): + """TFLite StableHLO REDUCE_WINDOW rejects non-max reducers in the pool subset.""" + buf = _build_stablehlo_reduce_window_model(reducer_name="STABLEHLO_ADD", init_value=0.0) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="MAXIMUM"): + from_tflite(tflite_model) + + +def test_stablehlo_reduce_window_base_dilation_unsupported(): + """TFLite StableHLO REDUCE_WINDOW rejects base dilation in the pool subset.""" + buf = _build_stablehlo_reduce_window_model(base_dilations=[1, 2, 1, 1]) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="base dilation"): + from_tflite(tflite_model) + + +@pytest.mark.parametrize( + "reducer_name, reduction", + [ + ("STABLEHLO_ADD", "add"), + ("STABLEHLO_MAXIMUM", "max"), + ("STABLEHLO_MINIMUM", "min"), + ("STABLEHLO_MULTIPLY", "mul"), + ], +) +def test_stablehlo_scatter(reducer_name, reduction): + """TFLite StableHLO SCATTER point updates lower to Relax scatter_nd.""" + mod = _load_model_from_buffer(_build_stablehlo_scatter_model(reducer_name)) + + @I.ir_module + class Expected: + @R.function + def main( + operand: R.Tensor((4,), dtype="float32"), + indices: R.Tensor((2, 1), dtype="int32"), + updates: R.Tensor((2,), dtype="float32"), + ) -> R.Tensor((4,), dtype="float32"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + gv: R.Tensor((4,), dtype="float32") = R.scatter_nd( + operand, indices, updates, reduction=reduction + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_scatter_unsupported_reducer(): + """TFLite StableHLO SCATTER rejects unsupported update computation ops.""" + buf = _build_stablehlo_scatter_model(reducer_name="STABLEHLO_SUBTRACT") + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="reducer"): + from_tflite(tflite_model) + + +def test_stablehlo_scatter_update_window_unsupported(): + """TFLite StableHLO SCATTER rejects slice update windows in the point subset.""" + buf = _build_stablehlo_scatter_model(update_window_dims=[0]) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="point updates"): + from_tflite(tflite_model) + + +def test_stablehlo_composite(): + """TFLite StableHLO COMPOSITE inlines a simple decomposition subgraph.""" + mod = _load_model_from_buffer(_build_stablehlo_composite_model()) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.negative(x) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_composite_does_not_overwrite_main_bindings(): + """TFLite StableHLO COMPOSITE decomposition tensor names are scoped locally.""" + mod = _load_model_from_buffer( + _build_stablehlo_composite_model(use_main_input_after_composite=True) + ) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) with R.dataflow(): - gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y) + lv: R.Tensor((2, 2), dtype="float32") = R.negative(x) + gv: R.Tensor((2, 2), dtype="float32") = R.add(x, lv) R.output(gv) return gv tvm.ir.assert_structural_equal(mod, Expected) +def test_stablehlo_composite_attributes_unsupported(): + """TFLite StableHLO COMPOSITE rejects attributes until they are parsed.""" + buf = _build_stablehlo_composite_model(with_attributes=True) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="composite attributes"): + from_tflite(tflite_model) + + @pytest.mark.parametrize( "builtin_name, relax_op, dtype, tensor_type", [ @@ -4987,6 +5741,404 @@ def test_stablehlo_dynamic_slice_out_of_bounds_unsupported(): from_tflite(tflite_model) +def test_stablehlo_cbrt(): + """TFLite StableHLO CBRT uses a sign-preserving composite expression.""" + mod = _load_model_from_buffer( + _build_stablehlo_model(builtin_name="STABLEHLO_CBRT", input_count=1) + ) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.negative(x) + lv1: R.Tensor((2, 2), dtype="float32") = R.power(lv, R.const(1.0 / 3.0, "float32")) + lv2: R.Tensor((2, 2), dtype="bool") = R.less(x, R.const(0, "float32")) + lv3: R.Tensor((2, 2), dtype="float32") = R.negative(lv1) + lv4: R.Tensor((2, 2), dtype="float32") = R.power(x, R.const(1.0 / 3.0, "float32")) + gv: R.Tensor((2, 2), dtype="float32") = R.where(lv2, lv3, lv4) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_remainder(): + """TFLite StableHLO REMAINDER uses truncating remainder semantics.""" + mod = _load_model_from_buffer( + _build_stablehlo_model(builtin_name="STABLEHLO_REMAINDER", input_count=2) + ) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 2), dtype="float32"), + y: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + lv: R.Tensor((2, 2), dtype="float32") = R.divide(x, y) + lv1: R.Tensor((2, 2), dtype="float32") = R.trunc(lv) + lv2: R.Tensor((2, 2), dtype="float32") = R.multiply(y, lv1) + gv: R.Tensor((2, 2), dtype="float32") = R.subtract(x, lv2) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def _build_stablehlo_dynamic_update_slice_model(start_vals, dynamic_starts=False): + """Build a minimal STABLEHLO_DYNAMIC_UPDATE_SLICE model.""" + builder = flatbuffers.Builder(1024) + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_UPDATE_SLICE") + op_code = _build_operator_code(builder, builtin_op) + + t_operand = _build_tensor(builder, 0, [3, 4]) + t_update = _build_tensor(builder, 1, [2, 2]) + start_tensors = [ + _build_tensor(builder, 2 + i, [], tensor_type=_tfl_tensor_type.INT32) + for i in range(len(start_vals)) + ] + out_idx = 2 + len(start_vals) + t_out = _build_tensor(builder, out_idx, [3, 4]) + tensors = [t_operand, t_update, *start_tensors, t_out] + + op_inputs = [0, 1, *range(2, out_idx)] + op = _build_operator(builder, 0, op_inputs, [out_idx]) + subgraph_inputs = op_inputs if dynamic_starts else [0, 1] + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[op], + inputs=subgraph_inputs, + outputs=[out_idx], + ) + if dynamic_starts: + buffers = [_build_buffer(builder) for _ in range(out_idx + 1)] + else: + start_buffers = [ + _build_buffer(builder, np.array([start], dtype=np.int32).tobytes()) + for start in start_vals + ] + buffers = [ + _build_buffer(builder), + _build_buffer(builder), + *start_buffers, + _build_buffer(builder), + ] + + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +def test_stablehlo_dynamic_update_slice(): + """TFLite StableHLO DYNAMIC_UPDATE_SLICE with static starts.""" + mod = _load_model_from_buffer(_build_stablehlo_dynamic_update_slice_model([1, 1])) + + @I.ir_module + class Expected: + @R.function + def main( + operand: R.Tensor((3, 4), dtype="float32"), + update: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((3, 4), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor((3, 4), dtype="float32") = R.scatter_nd( + operand, + R.const([[[1, 1], [1, 2]], [[2, 1], [2, 2]]], dtype="int64"), + update, + reduction="update", + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported(): + """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts is unsupported.""" + buf = _build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"): + from_tflite(tflite_model) + + +def test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported(): + """TFLite StableHLO DYNAMIC_UPDATE_SLICE rejects out-of-bounds updates.""" + buf = _build_stablehlo_dynamic_update_slice_model([2, 3]) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="out-of-bounds"): + from_tflite(tflite_model) + + +def _build_stablehlo_dot_general_model(lhs_contract, rhs_contract, lhs_batch=None, rhs_batch=None): + """Build a minimal STABLEHLO_DOT_GENERAL model.""" + builder = flatbuffers.Builder(1024) + lhs_batch = [] if lhs_batch is None else lhs_batch + rhs_batch = [] if rhs_batch is None else rhs_batch + + lhs_batch_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsBatchingDimensionsVector, + lhs_batch, + ) + rhs_batch_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsBatchingDimensionsVector, + rhs_batch, + ) + lhs_contract_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsContractingDimensionsVector, + lhs_contract, + ) + rhs_contract_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsContractingDimensionsVector, + rhs_contract, + ) + + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStart(builder) + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsBatchingDimensions( + builder, lhs_batch_vec + ) + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsBatchingDimensions( + builder, rhs_batch_vec + ) + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsContractingDimensions( + builder, lhs_contract_vec + ) + _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsContractingDimensions( + builder, rhs_contract_vec + ) + dot_opts = _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DOT_GENERAL") + op_code = _build_operator_code(builder, builtin_op) + t_lhs = _build_tensor(builder, 0, [2, 3]) + t_rhs = _build_tensor(builder, 1, [3, 4]) + t_out = _build_tensor(builder, 2, [2, 4]) + op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options2_type=_tfl_builtin_options2.StablehloDotGeneralOptions, + builtin_options2=dot_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_lhs, t_rhs, t_out], + operators=[op], + inputs=[0, 1], + outputs=[2], + ) + buffers = [_build_buffer(builder) for _ in range(3)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +def test_stablehlo_dot_general(): + """TFLite StableHLO DOT_GENERAL canonical 2D matmul.""" + mod = _load_model_from_buffer(_build_stablehlo_dot_general_model([1], [0])) + + @I.ir_module + class Expected: + @R.function + def main( + lhs: R.Tensor((2, 3), dtype="float32"), + rhs: R.Tensor((3, 4), dtype="float32"), + ) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor((2, 4), dtype="float32") = R.matmul(lhs, rhs, out_dtype="void") + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_dot_general_noncanonical_unsupported(): + """TFLite StableHLO DOT_GENERAL rejects non-canonical contracting dims.""" + buf = _build_stablehlo_dot_general_model([0], [0]) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="contracting"): + from_tflite(tflite_model) + + +def _build_stablehlo_convolution_model(feature_group_count=1, input_batch_dimension=0): + """Build a minimal STABLEHLO_CONVOLUTION model.""" + builder = flatbuffers.Builder(1024) + + window_strides_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowStridesVector, + [1, 1], + ) + padding_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartPaddingVector, + [0, 0, 0, 0], + ) + lhs_dilation_vec = _tflite_int64_vector( + builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartLhsDilationVector, [1, 1] + ) + rhs_dilation_vec = _tflite_int64_vector( + builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartRhsDilationVector, [1, 1] + ) + window_reversal_vec = _tflite_bool_vector( + builder, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowReversalVector, + [False, False], + ) + input_spatial_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartInputSpatialDimensionsVector, + [1, 2], + ) + kernel_spatial_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartKernelSpatialDimensionsVector, + [0, 1], + ) + output_spatial_vec = _tflite_int64_vector( + builder, + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartOutputSpatialDimensionsVector, + [1, 2], + ) + + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStart(builder) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowStrides( + builder, window_strides_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddPadding(builder, padding_vec) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddLhsDilation(builder, lhs_dilation_vec) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddRhsDilation(builder, rhs_dilation_vec) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowReversal( + builder, window_reversal_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputBatchDimension( + builder, input_batch_dimension + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputFeatureDimension(builder, 3) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputSpatialDimensions( + builder, input_spatial_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelInputFeatureDimension(builder, 2) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelOutputFeatureDimension(builder, 3) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelSpatialDimensions( + builder, kernel_spatial_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputBatchDimension(builder, 0) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputFeatureDimension(builder, 3) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputSpatialDimensions( + builder, output_spatial_vec + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddFeatureGroupCount( + builder, feature_group_count + ) + _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddBatchGroupCount(builder, 1) + conv_opts = _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_CONVOLUTION") + op_code = _build_operator_code(builder, builtin_op) + t_data = _build_tensor(builder, 0, [1, 5, 5, 2]) + t_kernel = _build_tensor(builder, 1, [3, 3, 2, 4]) + t_out = _build_tensor(builder, 2, [1, 3, 3, 4]) + op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options2_type=_tfl_builtin_options2.StablehloConvolutionOptions, + builtin_options2=conv_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_data, t_kernel, t_out], + operators=[op], + inputs=[0, 1], + outputs=[2], + ) + buffers = [_build_buffer(builder) for _ in range(3)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +def test_stablehlo_convolution(): + """TFLite StableHLO CONVOLUTION canonical NHWC/HWIO 2D convolution.""" + mod = _load_model_from_buffer(_build_stablehlo_convolution_model()) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((1, 5, 5, 2), dtype="float32"), + kernel: R.Tensor((3, 3, 2, 4), dtype="float32"), + ) -> R.Tensor((1, 3, 3, 4), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor((1, 3, 3, 4), dtype="float32") = R.nn.conv2d( + data, + kernel, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="HWIO", + out_layout="NHWC", + out_dtype="void", + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_convolution_feature_group_unsupported(): + """TFLite StableHLO CONVOLUTION rejects grouped convolution in the first subset.""" + buf = _build_stablehlo_convolution_model(feature_group_count=2) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="feature_group_count"): + from_tflite(tflite_model) + + +def test_stablehlo_convolution_dimension_numbers_unsupported(): + """TFLite StableHLO CONVOLUTION rejects non-canonical dimension numbers.""" + buf = _build_stablehlo_convolution_model(input_batch_dimension=1) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="dimension numbers"): + from_tflite(tflite_model) + + def _build_csr_sparsity( builder, *,