
[Relax][Frontend][TFLite] Support quantized TFLite import via QDQ decomposition#19538

Open: Aharrypotter wants to merge 6 commits into apache:main from Aharrypotter:tflite_quantization_params

Aharrypotter (Contributor) commented May 11, 2026

Summary

This PR adds initial quantized TFLite import support to the Relax frontend by
preserving tensor quantization metadata and replacing placeholder _qnn.op.*
frontend calls with an explicit QDQ decomposition:

dequantize -> float Relax op -> quantize

Before this PR, the Relax TFLite frontend raised NotImplementedError as soon
as quantization metadata was seen during tensor parsing, so no quantized
TFLite model could be imported. This PR keeps scale, zero_point, and
QuantizedDimension() in TensorWrapper.qnn_params, then uses the existing
R.quantize / R.dequantize operators to lower supported quantized paths.
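
As a rough numeric illustration of the dequantize -> float op -> quantize pattern described above, here is a minimal pure-Python sketch. It does not call TVM; the helper names, the int8 range, and the use of Python's built-in round() are assumptions made for illustration, and the rounding mode of R.quantize may differ.

```python
def dequantize(q, scale, zero_point):
    """Map integer values back to real values: (q - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map real values to integers: round(x / scale) + zero_point, clamped."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in x]


def qdq_relu(q_in, in_scale, in_zp, out_scale, out_zp):
    """dequantize -> float op (ReLU in this sketch) -> quantize."""
    x = dequantize(q_in, in_scale, in_zp)
    y = [max(0.0, v) for v in x]  # the ordinary float operator
    return quantize(y, out_scale, out_zp)


# Input and output tensors carry different quantization parameters.
q_out = qdq_relu([-20, 0, 10, 100], in_scale=0.5, in_zp=0,
                 out_scale=0.25, out_zp=-128)
print(q_out)
```

The float operator in the middle never sees integer values, which is why the frontend can reuse ordinary Relax float ops between the two conversions.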

Closes #19534.

Design

Relax already has R.quantize and R.dequantize with C++ registration, Python
APIs, legalization, and tests. Instead of introducing new fused Relax QNN ops
for this first import PR, the frontend now decomposes quantized TFLite operators
through QDQ around ordinary Relax float operators.

This keeps the change scoped to the Python TFLite frontend and existing Relax
QDQ operators, while establishing a working import path first. Fused int8 Relax
QNN operators can still be considered later if backend kernel selection requires
them.

Updated Converters

| Converter | Replacement |
| --- | --- |
| get_tensors | Preserve scale, zero_point, and QuantizedDimension() |
| quantize / dequantize helpers | Use R.quantize / R.dequantize with axis |
| convert_quantize | float -> Q and quantized requantize as DQ -> Q |
| convert_dequantize | Use R.dequantize |
| convert_relu, convert_relu6, convert_relu_n1_to_1 | DQ -> activation -> Q |
| _convert_elemwise | Quantized binary ops use DQ -> op -> fused activation -> Q; comparisons use DQ -> compare |
| convert_reshape | uint8 different-qparams path uses DQ -> reshape -> Q |
| _convert_reduce | Quantized reduce uses DQ -> reduce -> Q |
| convert_conv | Quantized Conv2D uses DQ input + DQ weight -> conv2d -> Q |
| convert_fully_connected | Quantized FC uses DQ input + DQ weight -> matmul -> Q |
| convert_concatenation | Quantized concat uses DQ each -> concat -> Q |
| convert_transpose_conv | Quantized transpose conv uses DQ input + DQ weight -> conv2d_transpose -> Q |
| convert_detection_postprocess | Inline _qnn.op.dequantize calls replaced with self.dequantize |

All _qnn.op.* references are removed, and the stale # ruff: noqa: F821
suppression is no longer needed.

Axis Remapping

The most correctness-sensitive part of this PR is axis remapping for per-channel
weight dequantization after the frontend rewrites TFLite layouts into Relax
layouts.

| Op | TFLite layout | Relax layout | Axis remap |
| --- | --- | --- | --- |
| Conv2D | [OC, KH, KW, IC] | [KH, KW, IC, OC] (HWIO) | 0 -> 3 |
| FullyConnected | [OC, IC] | [IC, OC] | 0 -> 1 |
| TransposeConv | [OC, KH, KW, IC] (OHWI) | [IC, OC, KH, KW] (IOHW) | 0 -> 1 |
| DepthwiseConv | [1, KH, KW, C*M] | [KH, KW, C, M] (HWOI) | per-channel unsupported |

For Conv2D, FC, and TransposeConv, non-zero weight QuantizedDimension() values
are rejected with OpAttributeInvalid, because the supported quantized TFLite
weight layout uses output-channel axis 0.
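
The remaps in the table follow from the layout permutations. The sketch below is illustrative only: remap_quantized_axis and LAYOUT_PERMS are hypothetical names, not frontend code, and ValueError stands in for the frontend's OpAttributeInvalid.

```python
def remap_quantized_axis(axis, perm):
    """Return where source axis `axis` lands after a transpose by `perm`.

    `perm` follows the usual transpose convention: output dimension i is
    taken from input dimension perm[i].
    """
    if axis != 0:
        # Stand-in for the frontend's OpAttributeInvalid rejection.
        raise ValueError("only output-channel axis 0 is supported")
    return perm.index(axis)


# Permutations implied by the layout table (assumed for this sketch).
LAYOUT_PERMS = {
    "conv2d": (1, 2, 3, 0),          # [OC, KH, KW, IC] -> [KH, KW, IC, OC]
    "fully_connected": (1, 0),       # [OC, IC] -> [IC, OC]
    "transpose_conv": (3, 0, 1, 2),  # [OC, KH, KW, IC] -> [IC, OC, KH, KW]
}

remapped = {op: remap_quantized_axis(0, perm) for op, perm in LAYOUT_PERMS.items()}
print(remapped)
```

Deriving the dequantize axis from the permutation, rather than hard-coding it per operator, is one way to keep the axis and the layout rewrite from drifting apart.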

Per-channel depthwise convolution is guarded with OpNotImplemented. The
TFLite depthwise reshape changes the channel-axis semantics in a way that this
initial QDQ lowering does not represent directly.
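
To see why the depthwise reshape is awkward for a single-axis dequantize, consider how a per-channel scale vector of length C*M is laid out after a row-major reshape to [C, M]. This is a minimal sketch with hypothetical helper names. It assumes the flat channel index is c * M + m, which is what a row-major reshape implies.

```python
def depthwise_scale_grid(scales, channels, multiplier):
    """Arrange C*M per-channel scales as a C x M grid (row-major reshape)."""
    if len(scales) != channels * multiplier:
        raise ValueError("expected one scale per output channel")
    return [scales[c * multiplier:(c + 1) * multiplier] for c in range(channels)]


def varies_along_single_axis(grid):
    """True if the scales depend on only one of the two axes (C or M)."""
    depends_on_c_only = all(len(set(row)) == 1 for row in grid)
    depends_on_m_only = all(len(set(col)) == 1 for col in zip(*grid))
    return depends_on_c_only or depends_on_m_only


grid = depthwise_scale_grid([0.1, 0.2, 0.3, 0.4], channels=2, multiplier=2)
single_axis = varies_along_single_axis(grid)
print(grid, single_axis)
```

When the scales differ along both C and M, as in this example, no single axis argument describes them, which is consistent with guarding the case instead of guessing an axis.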

Bias Handling

TFLite INT32/INT64 bias tensors may not store explicit quantization metadata.
For quantized Conv2D, FullyConnected, and TransposeConv, the frontend follows
the implicit TFLite convention and dequantizes integer bias using:

bias_scale = input_scale * weight_scale
bias_zero_point = 0
axis = 0

This supports both per-tensor and per-channel weight scales. The per-channel
case is covered by a structural regression test that expects vector bias scale.
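
The bias convention above can be checked numerically. The following pure-Python sketch is not the frontend implementation. dequantize_bias is a hypothetical helper, and a one-element weight_scales list is used here to represent per-tensor quantization.

```python
def dequantize_bias(bias_q, input_scale, weight_scales):
    """bias_real[c] = bias_q[c] * input_scale * weight_scale[c], zero point 0."""
    if len(weight_scales) == 1:
        # Per-tensor weights: broadcast the single scale to every channel.
        weight_scales = weight_scales * len(bias_q)
    if len(weight_scales) != len(bias_q):
        raise ValueError("weight scale count must match bias length")
    return [b * (input_scale * w) for b, w in zip(bias_q, weight_scales)]


# Per-tensor weight scale: one bias scale shared by all channels.
per_tensor = dequantize_bias([100, -200], input_scale=0.5, weight_scales=[0.25])

# Per-channel weight scales: the bias scale becomes a vector along axis 0.
per_channel = dequantize_bias([100, -200], input_scale=0.5,
                              weight_scales=[0.25, 0.125])
print(per_tensor, per_channel)
```

The per-channel call shows why the bias scale has to be a vector: each output channel combines the shared input scale with its own weight scale.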

Fused Activation Handling

Most quantized convolution-like paths preserve the existing quantized-domain
fused activation behavior:

float op -> Q -> quantized-domain clip

The elemwise QDQ path applies fused activation before the final quantize:

DQ -> float binary op -> float fused activation -> Q

Both paths are intentional and covered by regression tests:

  • quantized concat fused RELU checks the quantized-domain clip path
  • quantized add fused RELU6 checks the float-domain activation-before-Q path

This PR also fixes a latent R.clip call-site bug in the quantized fused
RELU helper by using max= rather than the unsupported a_max= keyword.
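
The two fused-activation orderings can be compared on a few values. The sketch below uses plain RELU, hypothetical helper names, an int8 range, and Python's round(). It only shows that the two orderings agree on these samples under those assumptions, not that they agree for every activation or rounding mode.

```python
def quantize_one(x, scale, zero_point, qmin=-128, qmax=127):
    """Quantize one real value: round(x / scale) + zero_point, clamped."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def relu_then_quantize(x, scale, zero_point):
    """Elemwise-style path: float activation, then Q."""
    return quantize_one(max(0.0, x), scale, zero_point)


def quantize_then_clip(x, scale, zero_point):
    """Conv/concat-style path: Q, then clip in the quantized domain.

    Real 0.0 quantizes to the zero point, so RELU becomes a lower clip there.
    """
    return max(quantize_one(x, scale, zero_point), zero_point)


samples = [-3.0, -0.25, 0.0, 0.5, 1.75, 50.0]
float_path = [relu_then_quantize(x, 0.25, -10) for x in samples]
quant_path = [quantize_then_clip(x, 0.25, -10) for x in samples]
print(float_path, quant_path)
```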

Safety Checks

  • Quantized elemwise non-comparison outputs must have output qparams. Missing
    output quantization metadata now raises OpAttributeInvalid instead of
    silently returning a float result.
  • Per-channel quantization rejects non-zero per-axis zero points, following the
    TFLite quantization specification.
  • Per-channel depthwise convolution is explicitly unsupported rather than
    importing with an incorrect axis interpretation.
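
The first two checks can be pictured as small guard functions. This is a hedged sketch only: OpAttributeInvalidError, both check functions, and is_rejected are hypothetical names standing in for the frontend's OpAttributeInvalid handling.

```python
class OpAttributeInvalidError(ValueError):
    """Hypothetical stand-in for the frontend's OpAttributeInvalid error."""


def check_per_channel_zero_points(zero_points):
    """Reject per-axis quantization whose zero points are not all zero."""
    if any(zp != 0 for zp in zero_points):
        raise OpAttributeInvalidError("per-channel zero points must all be 0")


def check_elemwise_output_qparams(output_qparams):
    """Reject a quantized elemwise op whose output has no qparams."""
    if not output_qparams:
        raise OpAttributeInvalidError("quantized elemwise output needs qparams")


def is_rejected(check, arg):
    """Small helper for the demo: report whether a check raises."""
    try:
        check(arg)
    except OpAttributeInvalidError:
        return True
    return False


print(is_rejected(check_per_channel_zero_points, [0, 3, 0]))
print(is_rejected(check_elemwise_output_qparams, None))
```

Raising at the converter keeps a malformed model from silently producing a float result or a wrongly scaled weight.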

Tests

The new tests build minimal TFLite flatbuffers directly and compare the imported
Relax IR with tvm.ir.assert_structural_equal. Unsupported-boundary tests use
pytest.raises.

| Test | Coverage |
| --- | --- |
| test_tensor_quantization_parameters_are_parsed | per-tensor and per-axis metadata parsing |
| test_quantize_op_uses_relax_quantize | TFLite QUANTIZE float input |
| test_quantize_op_requantize_uses_dq_q | TFLite QUANTIZE as requantize |
| test_dequantize_op_uses_relax_dequantize | TFLite DEQUANTIZE |
| test_quantized_add_uses_qdq | quantized ADD with differing input qparams |
| test_quantized_add_fused_relu6_uses_float_clip_before_quantize | elemwise fused activation before Q |
| test_quantized_add_without_output_qparams_invalid | invalid missing output qparams guard |
| test_quantized_conv2d_per_tensor_uses_qdq | Conv2D per-tensor QDQ |
| test_quantized_conv2d_per_channel_weight_uses_remapped_axis | Conv2D per-channel weight axis 0 -> 3 |
| test_quantized_conv2d_with_int32_bias_dequantizes_bias | Conv2D INT32 bias scale |
| test_quantized_conv2d_per_channel_weight_with_int32_bias_dequantizes_bias | Conv2D per-channel vector bias scale |
| test_quantized_concat_uses_qdq | concat QDQ path |
| test_quantized_concat_fused_relu_uses_quantized_clip | quantized-domain fused RELU clip |
| test_per_channel_depthwise_conv_unsupported | per-channel depthwise guard |
| test_uint8_reshape_requantize_uses_dq_reshape_q | uint8 reshape with different qparams |
| test_transpose_conv_with_int32_bias_dequantizes_bias | TransposeConv INT32 bias DQ |
| test_quantized_fully_connected_with_int32_bias_dequantizes_bias | FC INT32 bias DQ |

Local validation:

python -m ruff format --check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m ruff check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m pytest tests/python/relax/test_frontend_tflite.py -q

Result:

433 passed

Limitations

  • This PR prioritizes correct import and explicit Relax IR over fused int8
    kernel selection. The generated IR uses QDQ and float Relax operators.
  • Per-channel depthwise convolution remains unsupported.
  • The tests are structural IR tests. Numerical comparison against TFLite runtime
    outputs is left to follow-up work.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements support for quantized TFLite models in the Relax frontend by transitioning from specialized QNN operators to a dequantize-compute-quantize pattern using core Relax ops. Key changes include the addition of per-channel quantization support and updates to operators like convolution, fully connected, and activations. Feedback highlights potential runtime errors when handling per-channel quantized models due to the use of scalar-only constant extraction for scales; it is recommended to use relax.op.multiply to combine input and weight scales. Additionally, the reviewer suggests adding explicit validation for quantization axes in convolution and fully connected layers to ensure robustness against non-standard TFLite models.

Comment on lines +2554 to +2563
bias_scale_val = (
    get_scalar_from_constant(input_tensor.qnn_params["scale"])
    * get_scalar_from_constant(weight_tensor.qnn_params["scale"])
)
bias_expr = relax.op.dequantize(
    bias_expr,
    scale=relax.const(bias_scale_val, "float32"),
    zero_point=relax.const(0, "int32"),
    axis=0,
)
high

Using get_scalar_from_constant on weight_tensor.qnn_params["scale"] will cause an AssertionError for per-channel quantized models because the scale is a constant tensor, not a scalar. To support both per-tensor and per-channel quantization robustly, you should use relax.op.multiply on the scale expressions and pass the resulting expression directly to relax.op.dequantize.

Suggested change
- bias_scale_val = (
-     get_scalar_from_constant(input_tensor.qnn_params["scale"])
-     * get_scalar_from_constant(weight_tensor.qnn_params["scale"])
- )
- bias_expr = relax.op.dequantize(
-     bias_expr,
-     scale=relax.const(bias_scale_val, "float32"),
-     zero_point=relax.const(0, "int32"),
-     axis=0,
- )
+ elif input_tensor.qnn_params and bias_tensor_type in (
+     TensorType.INT32,
+     TensorType.INT64,
+ ):
+     bias_scale = relax.op.multiply(
+         input_tensor.qnn_params["scale"],
+         weight_tensor.qnn_params["scale"],
+     )
+     bias_expr = relax.op.dequantize(
+         bias_expr,
+         scale=bias_scale,
+         zero_point=relax.const(0, "int32"),
+         axis=0,
+     )

Reply from the PR author (Aharrypotter):

Done — replaced get_scalar_from_constant with relax.op.multiply so the bias scale works for both per-tensor and per-channel weight quantization. The multiply result is passed directly to relax.op.dequantize as a Relax expression.

Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py
Comment on lines +4384 to +4393
bias_scale_val = (
    get_scalar_from_constant(input_tensor.qnn_params["scale"])
    * get_scalar_from_constant(weights_tensor.qnn_params["scale"])
)
bias_expr = relax.op.dequantize(
    bias_expr,
    scale=relax.const(bias_scale_val, "float32"),
    zero_point=relax.const(0, "int32"),
    axis=0,
)
high

This get_scalar_from_constant call will crash on per-channel quantized TransposeConv models. Use relax.op.multiply to combine the input and weight scales into a single expression for the bias dequantization.

                bias_expr = relax.op.dequantize(
                    bias_expr,
                    scale=relax.op.multiply(
                        input_tensor.qnn_params["scale"],
                        weights_tensor.qnn_params["scale"],
                    ),
                    zero_point=relax.const(0, "int32"),
                    axis=0,
                )

Reply from the PR author (Aharrypotter):

Done — replaced get_scalar_from_constant with relax.op.multiply so the bias scale works for both per-tensor and per-channel weight quantization. The multiply result is passed directly to relax.op.dequantize as a Relax expression.

@Aharrypotter Aharrypotter force-pushed the tflite_quantization_params branch 3 times, most recently from 37e24ca to 88c8630 Compare May 15, 2026 02:15
@Aharrypotter Aharrypotter force-pushed the tflite_quantization_params branch from 88c8630 to f5075ec Compare May 19, 2026 09:01
Aharrypotter (Contributor, Author) commented:

cc @tlopex

tlopex (Member) commented May 21, 2026

Please resolve the conflict

Remove the global NotImplementedError guard in get_tensors() that blocked
all quantized TFLite models at the tensor-parsing stage.  The guard
prevented the frontend from advancing to operator conversion even when
only tensor-level metadata was needed.

Changes:
- Preserve scale and zero_point as before (per-tensor and per-axis)
- Additionally record axis = QuantizedDimension() in qnn_params
- Remove the global guard; errors now surface at specific operator
  converters rather than at tensor metadata parsing
- Update the F821 lint comment to reflect the new state

Test: add test_tensor_quantization_parameters_are_parsed which builds
a minimal TFLite flatbuffer with per-tensor and per-axis quantization
and verifies that TensorWrapper.qnn_params contains scale, zero_point,
and axis.  Assert that from_tflite() no longer fails at tensor parsing.

This is the first milestone of apache#19534 (quantized TFLite import).
Subsequent PRs will replace _qnn.op.* with Relax QDQ ops and add
quantized operator conversion.
…ntize

Replace the quantize() and dequantize() frontend helpers, which
previously referenced non-existent _qnn.op.quantize / _qnn.op.dequantize,
with the existing relax.op.quantize / relax.op.dequantize operators.

Changes:
- quantize(): _qnn.op.quantize -> relax.op.quantize, add axis param
- dequantize(): _qnn.op.dequantize -> relax.op.dequantize, add axis param
- Update F821 lint comment to enumerate remaining _qnn references

Tests:
- test_quantize_op_uses_relax_quantize: builds a minimal TFLite
  flatbuffer with QUANTIZE (float32 -> int8) and asserts the IR
  uses R.quantize with scale, zero_point, axis, and out_dtype
- test_dequantize_op_uses_relax_dequantize: builds a minimal TFLite
  flatbuffer with DEQUANTIZE (int8 -> float32) and asserts the IR
  uses R.dequantize with scale, zero_point, and axis

Part of apache#19534 (quantized TFLite import).  Subsequent PRs will
handle requantize, Conv2D, Dense, and remaining quantized ops.

Replace the _qnn.op.conv2d and _qnn.op.requantize calls in convert_conv()
with a DQ → float conv2d → Q flow using the existing relax.op.dequantize
and relax.op.quantize operators that were wired in PR apache#2.

Changes to convert_conv():
- Dequantize input activation and weight before the float conv2d.
- Remap per-channel weight QuantizedDimension() from the original
  TFLite layout (OC=0) to the HWIO layout (OC=3) for the dequantize axis.
- Dequantize INT32/INT64 bias before adding to the float conv output.
- Replace the fused _qnn.op.requantize + activation call with
  self.quantize() + convert_qnn_fused_activation_function().

Test: test_quantized_conv2d_per_tensor_uses_qdq builds a minimal
TFLite flatbuffer with a per-tensor quantized Conv2D and asserts
the IR uses dequantize → permute_dims → dequantize → conv2d →
quantize.

Known limitations (will be addressed in follow-ups):
- DepthwiseConv2D axis remap not yet handled.
- Per-channel weight test not yet added.
- INT32 bias dequantization uses input_scale only (not input_scale ×
  per-channel weight_scale).

Part of apache#19534 (quantized TFLite import).
…le ops

Replace the remaining _qnn.op.requantize calls in elementwise and
reshape/reduce converters with the DQ → float op → Q pattern using
the existing relax.op.dequantize / relax.op.quantize operators.

Converters updated:
- convert_relu: QNN fused RELU + requantize → DQ → relu → Q
- convert_relu6: QNN fused RELU6 + requantize → DQ → clip → Q
- convert_relu_n1_to_1: quantized clip + requantize → DQ → clip → Q
- convert_reshape: uint8 requantize → self.quantize
- _convert_reduce: int32 cast + requantize → DQ → op → Q
  (covers multinomial and all reduce-like ops)

Part of apache#19534 (quantized TFLite import).
…ized ops

Replace the last _qnn.op.* references in the TFLite frontend with
the DQ → float op → Q pattern, eliminating all references to the
non-existent _qnn module.

convert_fully_connected:
- _qnn.op.dense → DQ input + DQ weight (axis remap OC 0→1) + matmul
- _qnn.op.requantize + activation → self.quantize + activation
- INT32/INT64 bias dequantized with input_scale × weight_scale

convert_concatenation:
- _qnn.op.concat → DQ each input → float concat → quantize → activation

convert_transpose_conv:
- _qnn.op.conv2d_transpose → DQ input + DQ weight (axis remap OHWI→IOHW,
  OC axis 0→1) + float conv2d_transpose
- _qnn.op.requantize → self.quantize
- INT32/INT64 bias dequantized (previously missing — added in review fix)

convert_detection_postprocess:
- 3× _qnn.op.dequantize → self.dequantize

convert_reshape (uint8 path):
- Requantize on integer tensor → DQ → reshape → Q

Depthwise Conv2D:
- Explicit OpNotImplemented for per-channel depthwise (axis semantics
  change after [1,KH,KW,C*M] → [KH,KW,C,M] reshape)

Cleanup:
- Removed now-unnecessary F821 noqa comment (zero _qnn / _expr refs)
- Removed unused locals (weight_shape, output_tensor_type_str)

All _qnn.op.* references eliminated.  386 tests pass, ruff clean.
Closes apache#19534.
@Aharrypotter Aharrypotter force-pushed the tflite_quantization_params branch 3 times, most recently from 94b8772 to e0bea25 Compare May 22, 2026 16:56
Aharrypotter (Contributor, Author) commented May 22, 2026

Rebased the branch and resolved the merge conflicts.

During the post-rebase check, I found a few related issues and fixed them in the same update:

  • Avoid generated tflite FlatBuffer builder APIs in tests, since they are not available in CI.
  • Make the quantized elemwise path follow the same DQ -> float op -> Q pattern used by the other quantized converters.
  • Add regression tests for quantized elemwise ops, fused activation handling, and per-channel Conv2D bias / axis remapping.

These changes are closely related to the rebase and should help keep the PR passing.

Aharrypotter (Contributor, Author):

Update PR description

Use _get_tflite_schema_module() for TFLite builtin option builder helpers instead of accessing them directly from the tflite module. CI's tflite package does not reliably re-export these schema sub-module functions at the top level.

Also fix ruff lint issues (RUF002 ambiguous unicode, E501 long lines, I001 import sorting) and apply ruff format.
@Aharrypotter Aharrypotter force-pushed the tflite_quantization_params branch from e0bea25 to 95667ca Compare May 22, 2026 17:55

Development

Successfully merging this pull request may close these issues.

[Tracking Issue][TFLite] Support quantized operator import in Relax frontend
