@@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos,
489489 }
490490 float m = hn::ReduceMax (df, x);
491491 m = std::max (m, old_max);
492- x = hn::Exp (df, hn::Sub (x, hn::Set (df, m)));
492+ x = hn::FastExpMinusOrZero (df, hn::Sub (x, hn::Set (df, m)));
493493 float scale = old_d * std::exp (old_max - m);
494494 old_d = hn::ReduceSum (df, x) + scale;
495495 old_max = m;
@@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
538538 float m = hn::ReduceMax (df, x_max);
539539 m = std::max (m, old_max);
540540 VF m_vec = hn::Set (df, m);
541- x0 = hn::Exp (df, hn::Sub (x0, m_vec));
542- x1 = hn::Exp (df, hn::Sub (x1, m_vec));
541+ x0 = hn::FastExpMinusOrZero (df, hn::Sub (x0, m_vec));
542+ x1 = hn::FastExpMinusOrZero (df, hn::Sub (x1, m_vec));
543543 float scale = old_d * std::exp (old_max - m);
544544 VF x_sum = hn::Add (x0, x1);
545545 old_d = hn::ReduceSum (df, x_sum) + scale;
@@ -630,7 +630,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
630630 x_sum = Reduce4 (df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
631631 [](auto a, auto b) HWY_ATTR { return hn::Add (a, b); });
632632 }
633- VF4 scale = hn::Mul (old_d_vf, hn::Exp (df4, hn::Sub (old_max_vf, new_max)));
633+ VF4 scale = hn::Mul (
634+ old_d_vf, hn::FastExpMinusOrZero (df4, hn::Sub (old_max_vf, new_max)));
634635 old_d_vf = hn::Add (scale, x_sum);
635636 auto non_zero_mask = hn::Gt (old_d_vf, hn::Set (df4, 0 .0f ));
636637 const VF zero = hn::Zero (df);
@@ -790,7 +791,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
790791 x_6_sum, x_7_sum,
791792 [](auto a, auto b) HWY_ATTR { return hn::Add (a, b); });
792793 }
793- VF8 scale = hn::Mul (old_d_vf, hn::Exp (df8, hn::Sub (old_max_vf, new_max)));
794+ VF8 scale = hn::Mul (
795+ old_d_vf, hn::FastExpMinusOrZero (df8, hn::Sub (old_max_vf, new_max)));
794796 old_d_vf = hn::Add (scale, x_sum);
795797 auto non_zero_mask = hn::Gt (old_d_vf, hn::Set (df8, 0 .0f ));
796798 const VF zero = hn::Zero (df);
0 commit comments