When running Gemma 4 within mlx-lm (specifically under continuous batching), token generation degenerates into total repetition (e.g. "He He He" or "There There There").
After deep investigation, we tracked the underlying bug to a mathematical drift / evaluation order mismatch inside mx.fast.scaled_dot_product_attention when evaluating 4-bit quantized layers.
When the fast attention kernel is invoked with an explicit padding mask (e.g., an array of True), it follows a different evaluation graph internally than when it is passed mask=None. Under 4-bit quantization, the floating-point drift between these two paths is exceptionally severe for Gemma 4 architectures, flipping top-1 logprobs and trapping the sampler in cyclic repetition.
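Mathematically, an all-True boolean mask is a no-op: masked softmax with every position allowed must equal the unmasked softmax. A minimal NumPy reference (illustrative only, not the MLX kernel) shows that when both paths are computed the same way the results are bit-identical, so any divergence must come from the fused kernel taking a different internal branch:

```python
import numpy as np

def ref_attention(q, k, v, scale, mask=None):
    # Reference scaled dot-product attention in float32.
    scores = (q @ k.transpose(0, 1, 3, 2)) * scale
    if mask is not None:
        # Boolean mask: False positions are excluded from the softmax.
        # With an all-True mask, np.where returns `scores` unchanged.
        scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
B, H, L, D = 1, 2, 4, 8
q = rng.standard_normal((B, H, 1, D)).astype(np.float32)
k = rng.standard_normal((B, H, L, D)).astype(np.float32)
v = rng.standard_normal((B, H, L, D)).astype(np.float32)
scale = 1.0 / np.sqrt(D)

out_none = ref_attention(q, k, v, scale, mask=None)
out_true = ref_attention(q, k, v, scale, mask=np.ones((B, 1, 1, L), dtype=bool))
print(np.array_equal(out_none, out_true))  # True: bit-identical in the reference
```

In a single-path reference implementation the two cases cannot differ; the fused kernel only diverges because mask=None and explicit-mask inputs are routed to different code paths.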
Impact on MLX-LM (BatchKVCache)
Because BatchKVCache.make_mask() constructs and injects an explicit boolean mx.array mask even for single-token (N=1) generation steps (whereas the regular KVCache passes None), generation collapses immediately under continuous multi-tenant batching. To work around this, developers must either disable continuous batching (max_num_seqs=1) or apply local monkey-patches that group cached sequences by offset so that mask=None can be passed to the kernel.
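One workaround shape is to collapse an all-True mask to None before it reaches the fused kernel. The helper below is a hypothetical sketch (the name `coalesce_mask` is ours, shown with NumPy arrays as a stand-in for mx.array); it routes the mathematically equivalent all-True case through the known-good mask=None path while leaving real padding masks untouched:

```python
import numpy as np

def coalesce_mask(mask):
    """Return None when a boolean mask selects every position.

    Hypothetical helper: mask=None and an all-True boolean mask are
    mathematically equivalent, so sending both through the mask=None
    kernel path sidesteps the divergent branch. Masks with any False
    entry (real padding) and non-boolean (additive) masks pass through.
    """
    if mask is not None and mask.dtype == bool and mask.all():
        return None
    return mask

m = np.ones((2, 1, 1, 16), dtype=bool)
print(coalesce_mask(m) is None)  # True: all-True mask is safe to drop
m[0, 0, 0, 3] = False
print(coalesce_mask(m) is m)     # True: real padding, mask is kept
```

This only helps for steps where no sequence in the batch needs padding; batches with mixed lengths still require the explicit mask, which is why a kernel-level fix is the real request.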
Reproduction
When comparing identical sequence states and prompts on an mlx_vlm quantized Gemma 4 model:
Path A (Operates Flawlessly):
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
Path B (Yields divergent probabilities leading to repetition):
# Even if the mask is universally True, the non-fused branching is divergent
explicit_mask = mx.ones((batch_size, 1, q_len, kv_len), dtype=mx.bool_)
mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=explicit_mask)
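The divergence matters because greedy decoding is brittle when the top-2 logits are close. The toy example below (illustrative values, not measured from Gemma) shows how a drift on the order of 1e-3, plausible under 4-bit quantization, flips the argmax token and can seed a repetition loop:

```python
import numpy as np

# Two near-identical logit vectors, standing in for the two kernel paths.
# When the top-2 logits are close, a tiny per-path drift flips the
# greedy (argmax) choice, which is all a repetition loop needs to start.
logits_a = np.array([2.4990, 2.5000, 0.1], dtype=np.float32)  # path A (mask=None)
logits_b = logits_a.copy()
logits_b[0] += 2e-3                                           # path B drift

print(int(np.argmax(logits_a)))  # 1
print(int(np.argmax(logits_b)))  # 0: top-1 flipped by the drift
```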
Requested Fix
Review the Metal shader / kernel paths behind mx.fast.scaled_dot_product_attention so that an explicit (all-True) mask array and mask=None produce numerically identical (or at least stable) aggregation in quantized attention blocks.