## Summary
`mx.fast.scaled_dot_product_attention` hangs indefinitely or crashes with SIGSEGV when called under `mx.vmap` with grouped-query attention (GQA) shapes where `n_heads != n_kv_heads`. Both failure modes are observed across runs with the same inputs. The same shapes work correctly without `vmap`, and equal-head (MHA) shapes work correctly with `vmap`.
## Reproducer
```python
import mlx.core as mx

def f(qi, ki, vi):
    return mx.mean(mx.fast.scaled_dot_product_attention(
        qi[None], ki[None], vi[None], scale=0.125))

# GQA: 4 query heads, 2 KV heads
H_q, H_kv, L, D = 4, 2, 4, 64
q = mx.random.normal((2, H_q, L, D))
k = mx.random.normal((2, H_kv, L, D))
v = mx.random.normal((2, H_kv, L, D))
mx.eval(q, k, v)

# Hangs indefinitely or crashes with SIGSEGV:
out = mx.vmap(f)(q, k, v)
mx.eval(out)
```
Changing `H_kv` to `4` (MHA) makes it pass instantly.
## Diagnostic matrix

All tests use the same structure, varying only head counts and transform composition:

| H_q | H_kv | Transform | Result |
|-----|------|-----------|--------|
| 4 | 4 | vmap(fwd) | PASS (0.01s) |
| 4 | 4 | vmap(grad) | PASS (0.01s) |
| 4 | 2 | grad | PASS (instant) |
| 4 | 2 | vmap(fwd) | HANG / SIGSEGV |
| 4 | 2 | vmap(grad) | HANG |
| 4 | 1 | vmap(fwd) | HANG |
The boundary is precisely `n_heads != n_kv_heads` under `vmap`. Equal-head configurations pass. Non-vmap'd `grad` passes even with GQA shapes. The bug is in the `ScaledDotProductAttention` primitive's vmap rule, not in the grad composition.
## Workaround

Replacing the fused SDPA with a decomposed `matmul → softmax → matmul` (with explicit `mx.repeat` for KV-head expansion) works correctly under `vmap` and `vmap(grad)` for all head configurations.
## Impact

This blocks any `vmap` composition involving SDPA on GQA/MQA models. GQA is used in essentially all recent open-weight LLMs (Qwen 2/2.5/3, Llama 3/3.1/3.2/3.3, Mistral, Phi-3, Gemma 2). Any use case requiring per-sample operations through attention (per-sample gradients, per-sample Jacobians, batched inference with varying parameters) is affected.
## Environment
- MLX: 0.31.1 (latest)
- Python: 3.12.4
- macOS: 26.3.1 (Tahoe)
- Chip: Apple M1 Pro (16GB)