Skip to content

Commit af5e13a

Browse files
Nikhil0250copybara-github
authored andcommitted
Replace remaining occurrences of Exp with FastExpMinusOrZero in flash attention.
PiperOrigin-RevId: 882817324
1 parent 0da94e5 commit af5e13a

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

gemma/flash_attention.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos,
489489
}
490490
float m = hn::ReduceMax(df, x);
491491
m = std::max(m, old_max);
492-
x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
492+
x = hn::FastExpMinusOrZero(df, hn::Sub(x, hn::Set(df, m)));
493493
float scale = old_d * std::exp(old_max - m);
494494
old_d = hn::ReduceSum(df, x) + scale;
495495
old_max = m;
@@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
538538
float m = hn::ReduceMax(df, x_max);
539539
m = std::max(m, old_max);
540540
VF m_vec = hn::Set(df, m);
541-
x0 = hn::Exp(df, hn::Sub(x0, m_vec));
542-
x1 = hn::Exp(df, hn::Sub(x1, m_vec));
541+
x0 = hn::FastExpMinusOrZero(df, hn::Sub(x0, m_vec));
542+
x1 = hn::FastExpMinusOrZero(df, hn::Sub(x1, m_vec));
543543
float scale = old_d * std::exp(old_max - m);
544544
VF x_sum = hn::Add(x0, x1);
545545
old_d = hn::ReduceSum(df, x_sum) + scale;
@@ -630,7 +630,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
630630
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
631631
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
632632
}
633-
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
633+
VF4 scale = hn::Mul(
634+
old_d_vf, hn::FastExpMinusOrZero(df4, hn::Sub(old_max_vf, new_max)));
634635
old_d_vf = hn::Add(scale, x_sum);
635636
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
636637
const VF zero = hn::Zero(df);
@@ -790,7 +791,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
790791
x_6_sum, x_7_sum,
791792
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
792793
}
793-
VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max)));
794+
VF8 scale = hn::Mul(
795+
old_d_vf, hn::FastExpMinusOrZero(df8, hn::Sub(old_max_vf, new_max)));
794796
old_d_vf = hn::Add(scale, x_sum);
795797
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f));
796798
const VF zero = hn::Zero(df);

0 commit comments

Comments
 (0)