[Relax][Frontend][TFLite] Support quantized TFLite import via QDQ decomposition#19538
[Relax][Frontend][TFLite] Support quantized TFLite import via QDQ decomposition#19538Aharrypotter wants to merge 6 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements support for quantized TFLite models in the Relax frontend by transitioning from specialized QNN operators to a dequantize-compute-quantize pattern using core Relax ops. Key changes include the addition of per-channel quantization support and updates to operators like convolution, fully connected, and activations. Feedback highlights potential runtime errors when handling per-channel quantized models due to the use of scalar-only constant extraction for scales; it is recommended to use relax.op.multiply to combine input and weight scales. Additionally, the reviewer suggests adding explicit validation for quantization axes in convolution and fully connected layers to ensure robustness against non-standard TFLite models.
| bias_scale_val = ( | ||
| get_scalar_from_constant(input_tensor.qnn_params["scale"]) | ||
| * get_scalar_from_constant(weight_tensor.qnn_params["scale"]) | ||
| ) | ||
| bias_expr = relax.op.dequantize( | ||
| bias_expr, | ||
| scale=relax.const(bias_scale_val, "float32"), | ||
| zero_point=relax.const(0, "int32"), | ||
| axis=0, | ||
| ) |
There was a problem hiding this comment.
Using get_scalar_from_constant on weight_tensor.qnn_params["scale"] will cause an AssertionError for per-channel quantized models because the scale is a constant tensor, not a scalar. To support both per-tensor and per-channel quantization robustly, you should use relax.op.multiply on the scale expressions and pass the resulting expression directly to relax.op.dequantize.
| bias_scale_val = ( | |
| get_scalar_from_constant(input_tensor.qnn_params["scale"]) | |
| * get_scalar_from_constant(weight_tensor.qnn_params["scale"]) | |
| ) | |
| bias_expr = relax.op.dequantize( | |
| bias_expr, | |
| scale=relax.const(bias_scale_val, "float32"), | |
| zero_point=relax.const(0, "int32"), | |
| axis=0, | |
| ) | |
| elif input_tensor.qnn_params and bias_tensor_type in ( | |
| TensorType.INT32, | |
| TensorType.INT64, | |
| ): | |
| bias_scale = relax.op.multiply( | |
| input_tensor.qnn_params["scale"], | |
| weight_tensor.qnn_params["scale"], | |
| ) | |
| bias_expr = relax.op.dequantize( | |
| bias_expr, | |
| scale=bias_scale, | |
| zero_point=relax.const(0, "int32"), | |
| axis=0, | |
| ) |
There was a problem hiding this comment.
Done — replaced get_scalar_from_constant with relax.op.multiply so the bias scale works for both per-tensor and per-channel weight quantization. The multiply result is passed directly to relax.op.dequantize as a Relax expression.
| bias_scale_val = ( | ||
| get_scalar_from_constant(input_tensor.qnn_params["scale"]) | ||
| * get_scalar_from_constant(weights_tensor.qnn_params["scale"]) | ||
| ) | ||
| bias_expr = relax.op.dequantize( | ||
| bias_expr, | ||
| scale=relax.const(bias_scale_val, "float32"), | ||
| zero_point=relax.const(0, "int32"), | ||
| axis=0, | ||
| ) |
There was a problem hiding this comment.
This get_scalar_from_constant call will crash on per-channel quantized TransposeConv models. Use relax.op.multiply to combine the input and weight scales into a single expression for the bias dequantization.
bias_expr = relax.op.dequantize(
bias_expr,
scale=relax.op.multiply(
input_tensor.qnn_params["scale"],
weights_tensor.qnn_params["scale"],
),
zero_point=relax.const(0, "int32"),
axis=0,
)There was a problem hiding this comment.
Done — replaced get_scalar_from_constant with relax.op.multiply so the bias scale works for both per-tensor and per-channel weight quantization. The multiply result is passed directly to relax.op.dequantize as a Relax expression.
37e24ca to
88c8630
Compare
88c8630 to
f5075ec
Compare
|
cc @tlopex |
|
Please resolve the conflict |
Remove the global NotImplementedError guard in get_tensors() that blocked all quantized TFLite models at the tensor-parsing stage. The guard prevented the frontend from advancing to operator conversion even when only tensor-level metadata was needed. Changes: - Preserve scale and zero_point as before (per-tensor and per-axis) - Additionally record axis = QuantizedDimension() in qnn_params - Remove the global guard; errors now surface at specific operator converters rather than at tensor metadata parsing - Update the F821 lint comment to reflect the new state Test: add test_tensor_quantization_parameters_are_parsed which builds a minimal TFLite flatbuffer with per-tensor and per-axis quantization and verifies that TensorWrapper.qnn_params contains scale, zero_point, and axis. Assert that from_tflite() no longer fails at tensor parsing. This is the first milestone of apache#19534 (quantized TFLite import). Subsequent PRs will replace _qnn.op.* with Relax QDQ ops and add quantized operator conversion.
…ntize Replace the quantize() and dequantize() frontend helpers, which previously referenced non-existent _qnn.op.quantize / _qnn.op.dequantize, with the existing relax.op.quantize / relax.op.dequantize operators. Changes: - quantize(): _qnn.op.quantize -> relax.op.quantize, add axis param - dequantize(): _qnn.op.dequantize -> relax.op.dequantize, add axis param - Update F821 lint comment to enumerate remaining _qnn references Tests: - test_quantize_op_uses_relax_quantize: builds a minimal TFLite flatbuffer with QUANTIZE (float32 -> int8) and asserts the IR uses R.quantize with scale, zero_point, axis, and out_dtype - test_dequantize_op_uses_relax_dequantize: builds a minimal TFLite flatbuffer with DEQUANTIZE (int8 -> float32) and asserts the IR uses R.dequantize with scale, zero_point, and axis Part of apache#19534 (quantized TFLite import). Subsequent PRs will handle requantize, Conv2D, Dense, and remaining quantized ops.
Replace the _qnn.op.conv2d and _qnn.op.requantize calls in convert_conv() with a DQ → float conv2d → Q flow using the existing relax.op.dequantize and relax.op.quantize operators that were wired in PR apache#2. Changes to convert_conv(): - Dequantize input activation and weight before the float conv2d. - Remap per-channel weight QuantizedDimension() from the original TFLite layout (OC=0) to the HWIO layout (OC=3) for the dequantize axis. - Dequantize INT32/INT64 bias before adding to the float conv output. - Replace the fused _qnn.op.requantize + activation call with self.quantize() + convert_qnn_fused_activation_function(). Test: test_quantized_conv2d_per_tensor_uses_qdq builds a minimal TFLite flatbuffer with a per-tensor quantized Conv2D and asserts the IR uses dequantize → permute_dims → dequantize → conv2d → quantize. Known limitations (will be addressed in follow-ups): - DepthwiseConv2D axis remap not yet handled. - Per-channel weight test not yet added. - INT32 bias dequantization uses input_scale only (not input_scale × per-channel weight_scale). Part of apache#19534 (quantized TFLite import).
…le ops Replace the remaining _qnn.op.requantize calls in elementwise and reshape/reduce converters with the DQ → float op → Q pattern using the existing relax.op.dequantize / relax.op.quantize operators. Converters updated: - convert_relu: QNN fused RELU + requantize → DQ → relu → Q - convert_relu6: QNN fused RELU6 + requantize → DQ → clip → Q - convert_relu_n1_to_1: quantized clip + requantize → DQ → clip → Q - convert_reshape: uint8 requantize → self.quantize - _convert_reduce: int32 cast + requantize → DQ → op → Q (covers multinomial and all reduce-like ops) Part of apache#19534 (quantized TFLite import).
…ized ops Replace the last _qnn.op.* references in the TFLite frontend with the DQ → float op → Q pattern, eliminating all references to the non-existent _qnn module. convert_fully_connected: - _qnn.op.dense → DQ input + DQ weight (axis remap OC 0→1) + matmul - _qnn.op.requantize + activation → self.quantize + activation - INT32/INT64 bias dequantized with input_scale × weight_scale convert_concatenation: - _qnn.op.concat → DQ each input → float concat → quantize → activation convert_transpose_conv: - _qnn.op.conv2d_transpose → DQ input + DQ weight (axis remap OHWI→IOHW, OC axis 0→1) + float conv2d_transpose - _qnn.op.requantize → self.quantize - INT32/INT64 bias dequantized (previously missing — added in review fix) convert_detection_postprocess: - 3× _qnn.op.dequantize → self.dequantize convert_reshape (uint8 path): - Requantize on integer tensor → DQ → reshape → Q Depthwise Conv2D: - Explicit OpNotImplemented for per-channel depthwise (axis semantics change after [1,KH,KW,C*M] → [KH,KW,C,M] reshape) Cleanup: - Removed now-unnecessary F821 noqa comment (zero _qnn / _expr refs) - Removed unused locals (weight_shape, output_tensor_type_str) All _qnn.op.* references eliminated. 386 tests pass, ruff clean. Closes apache#19534.
94b8772 to
e0bea25
Compare
|
Rebased the branch and resolved the merge conflicts. During the post-rebase check, I found a few related issues and fixed them in the same update:
These changes are closely related to the rebase and should help keep the PR passing. |
|
Update PR description |
Use _get_tflite_schema_module() for TFLite builtin option builder helpers instead of accessing them directly from the tflite module. CI's tflite package does not reliably re-export these schema sub-module functions at the top level. Also fix ruff lint issues (RUF002 ambiguous unicode, E501 long lines, I001 import sorting) and apply ruff format.
e0bea25 to
95667ca
Compare
Summary
This PR adds initial quantized TFLite import support to the Relax frontend by
preserving tensor quantization metadata and replacing placeholder
_qnn.op.*frontend calls with an explicit QDQ decomposition:
Before this PR, the Relax TFLite frontend raised
NotImplementedErroras soonas quantization metadata was seen during tensor parsing. This made quantized
TFLite models unreachable. This PR keeps
scale,zero_point, andQuantizedDimension()inTensorWrapper.qnn_params, then uses the existingR.quantize/R.dequantizeoperators to lower supported quantized paths.Closes #19534.
Design
Relax already has
R.quantizeandR.dequantizewith C++ registration, PythonAPIs, legalization, and tests. Instead of introducing new fused Relax QNN ops
for this first import PR, the frontend now decomposes quantized TFLite operators
through QDQ around ordinary Relax float operators.
This keeps the change scoped to the Python TFLite frontend and existing Relax
QDQ operators, while establishing a working import path first. Fused int8 Relax
QNN operators can still be considered later if backend kernel selection requires
them.
Updated Converters
get_tensorsscale,zero_point, andQuantizedDimension()quantize/dequantizehelpersR.quantize/R.dequantizewithaxisconvert_quantizefloat -> Qand quantized requantize asDQ -> Qconvert_dequantizeR.dequantizeconvert_relu,convert_relu6,convert_relu_n1_to_1DQ -> activation -> Q_convert_elemwiseDQ -> op -> fused activation -> Q; comparisons useDQ -> compareconvert_reshapeDQ -> reshape -> Q_convert_reduceDQ -> reduce -> Qconvert_convDQ input + DQ weight -> conv2d -> Qconvert_fully_connectedDQ input + DQ weight -> matmul -> Qconvert_concatenationDQ each -> concat -> Qconvert_transpose_convDQ input + DQ weight -> conv2d_transpose -> Qconvert_detection_postprocess_qnn.op.dequantizecalls replaced withself.dequantizeAll
_qnn.op.*references are removed, and the stale# ruff: noqa: F821suppression is no longer needed.
Axis Remapping
The most correctness-sensitive part of this PR is axis remapping for per-channel
weight dequantization after the frontend rewrites TFLite layouts into Relax
layouts.
[OC, KH, KW, IC][KH, KW, IC, OC](HWIO)0 -> 3[OC, IC][IC, OC]0 -> 1[OC, KH, KW, IC](OHWI)[IC, OC, KH, KW](IOHW)0 -> 1[1, KH, KW, C*M][KH, KW, C, M](HWOI)For Conv2D, FC, and TransposeConv, non-zero weight
QuantizedDimension()valuesare rejected with
OpAttributeInvalid, because the supported quantized TFLiteweight layout uses output-channel axis 0.
Per-channel depthwise convolution is guarded with
OpNotImplemented. TheTFLite depthwise reshape changes the channel-axis semantics in a way that this
initial QDQ lowering does not represent directly.
Bias Handling
TFLite INT32/INT64 bias tensors may not store explicit quantization metadata.
For quantized Conv2D, FullyConnected, and TransposeConv, the frontend follows
the implicit TFLite convention and dequantizes integer bias using:
This supports both per-tensor and per-channel weight scales. The per-channel
case is covered by a structural regression test that expects vector bias scale.
Fused Activation Handling
Most quantized convolution-like paths preserve the existing quantized-domain
fused activation behavior:
The elemwise QDQ path applies fused activation before the final quantize:
Both paths are intentional and covered by regression tests:
RELUchecks the quantized-domain clip pathRELU6checks the float-domain activation-before-Q pathThis PR also fixes a latent
R.clipcall-site bug in the quantized fusedRELUhelper by usingmax=rather than the unsupporteda_max=keyword.Safety Checks
output quantization metadata now raises
OpAttributeInvalidinstead ofsilently returning a float result.
TFLite quantization specification.
importing with an incorrect axis interpretation.
Tests
The new tests build minimal TFLite flatbuffers directly and compare the imported
Relax IR with
tvm.ir.assert_structural_equal. Unsupported-boundary tests usepytest.raises.test_tensor_quantization_parameters_are_parsedtest_quantize_op_uses_relax_quantizeQUANTIZEfloat inputtest_quantize_op_requantize_uses_dq_qQUANTIZEas requantizetest_dequantize_op_uses_relax_dequantizeDEQUANTIZEtest_quantized_add_uses_qdqtest_quantized_add_fused_relu6_uses_float_clip_before_quantizetest_quantized_add_without_output_qparams_invalidtest_quantized_conv2d_per_tensor_uses_qdqtest_quantized_conv2d_per_channel_weight_uses_remapped_axis0 -> 3test_quantized_conv2d_with_int32_bias_dequantizes_biastest_quantized_conv2d_per_channel_weight_with_int32_bias_dequantizes_biastest_quantized_concat_uses_qdqtest_quantized_concat_fused_relu_uses_quantized_cliptest_per_channel_depthwise_conv_unsupportedtest_uint8_reshape_requantize_uses_dq_reshape_qtest_transpose_conv_with_int32_bias_dequantizes_biastest_quantized_fully_connected_with_int32_bias_dequantizes_biasLocal validation:
Result:
Limitations
kernel selection. The generated IR uses QDQ and float Relax operators.
outputs is left to follow-up work.
References