
[Bug] Numerical divergence / Token repetition in fast.scaled_dot_product_attention for 4-bit Quantized Models #3384

@Mengzhen-Zhang

Description


When running Gemma 4 within mlx-lm (specifically with continuous batching), token generation degenerates into total repetition (e.g., "He He He" or "There There There").

After a deep investigation, we tracked the underlying bug to mathematical drift / an evaluation-order mismatch inside mx.fast.scaled_dot_product_attention when evaluating 4-bit quantized layers.

When the fast attention kernel is invoked with an explicit padding mask (e.g., an array of all True values), it follows a different internal evaluation graph than when it is passed mask=None. Under 4-bit quantization, the floating-point drift between these two paths is severe enough on Gemma 4 architectures to flip top-1 logprobs and trap the sampler in cyclic repetition.

Impact on MLX-LM (BatchKVCache)

Because BatchKVCache.make_mask() constructs and injects an explicit boolean mx.array mask even for single-token (N=1) generation steps (whereas the regular KVCache passes None), continuous multi-tenant batching immediately collapses generation. To bypass this, developers are forced to disable continuous batching (max_num_seqs=1) or apply local offset-grouped caching monkey-patches that force the mask=None path.
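Until the kernel paths are reconciled, one possible mitigation is to collapse an all-True mask back to None before invoking the fused kernel, which is mathematically a no-op. The sketch below uses numpy arrays as stand-ins for mx.array, and the helper name is hypothetical:

```python
import numpy as np

def collapse_trivial_mask(mask):
    """Hypothetical helper: return None when a boolean mask keeps every
    position, so callers fall back to the fused mask=None kernel path."""
    if mask is not None and mask.dtype == np.bool_ and bool(np.all(mask)):
        return None
    return mask

# An all-True padding mask selects every position, so dropping it is safe;
# a ragged mask (real padding) must be kept as-is.
dense = np.ones((2, 1, 1, 16), dtype=bool)
ragged = dense.copy()
ragged[1, ..., 8:] = False  # second sequence has padded positions

assert collapse_trivial_mask(dense) is None
assert collapse_trivial_mask(ragged) is ragged
```

This only helps single-token steps where every sequence in the batch attends to all cached positions; batches with genuine padding still take the explicit-mask path.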

Reproduction

When comparing identical sequence states and prompts on an mlx_vlm quantized Gemma 4 model:

Path A (operates flawlessly):

```python
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
```

Path B (yields divergent probabilities, leading to repetition):

```python
# Even when the mask is universally True, the non-fused branch diverges
explicit_mask = mx.ones((batch_size, 1, q_len, kv_len), dtype=mx.bool_)
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=explicit_mask)
```
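In exact arithmetic an all-True mask is a no-op, so any divergence between the two paths is purely a kernel-routing effect. A numpy reference implementation (my own illustration in float64, not the MLX kernel) confirming the mathematical equivalence of the two call patterns:

```python
import numpy as np

def sdpa_ref(q, k, v, scale, mask=None):
    # Reference scaled dot-product attention: scores, optional boolean
    # mask (True = attend), numerically stable softmax, weighted sum.
    scores = (q @ k.transpose(0, 1, 3, 2)) * scale
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
q = rng.normal(size=(1, 2, 1, 8))    # (batch, heads, q_len, head_dim)
k = rng.normal(size=(1, 2, 16, 8))   # (batch, heads, kv_len, head_dim)
v = rng.normal(size=(1, 2, 16, 8))
mask = np.ones((1, 1, 1, 16), dtype=bool)  # all-True, broadcasts over heads

out_none = sdpa_ref(q, k, v, 8 ** -0.5, None)
out_mask = sdpa_ref(q, k, v, 8 ** -0.5, mask)
assert np.allclose(out_none, out_mask)
```

In float64 the two paths agree to machine precision; the bug report is that the fused Metal kernels do not preserve this equivalence under 4-bit quantized weights.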

Requested Fix

Review the underlying Metal shader / kernel paths in mx.fast.scaled_dot_product_attention to ensure numerically identical (or at least stable) aggregation whether an explicit mask array is supplied or the mask=None branch is taken inside quantized attention blocks.
