@@ -86,9 +86,7 @@ pub fn replay_attention_reference(
8686 let kv_head = qh / heads_per_kv;
8787
8888 // Extract Q head
89- let q_head: Vec < f64 > = ( 0 ..d_head)
90- . map ( |i| q_i8[ qh * d_head + i] as f64 )
91- . collect ( ) ;
89+ let q_head: Vec < f64 > = ( 0 ..d_head) . map ( |i| q_i8[ qh * d_head + i] as f64 ) . collect ( ) ;
9290
9391 // Compute attention scores: score[t] = q · k_t / sqrt(d)
9492 let scores: Vec < f64 > = ( 0 ..seq_len)
@@ -153,16 +151,18 @@ pub fn replay_attention_roped(
153151 let heads_per_kv = cfg. n_q_heads / cfg. n_kv_heads ;
154152 let inv_sqrt_d = 1.0 / ( d_head as f64 ) . sqrt ( ) ;
155153 let seq_len = kv_cache_k_roped. len ( ) ;
156- let inv_scale = if scale_a. abs ( ) > 1e-30 { 1.0 / scale_a } else { 1.0 } ;
154+ let inv_scale = if scale_a. abs ( ) > 1e-30 {
155+ 1.0 / scale_a
156+ } else {
157+ 1.0
158+ } ;
157159
158160 let mut a = vec ! [ 0i8 ; cfg. hidden_dim] ;
159161
160162 for qh in 0 ..cfg. n_q_heads {
161163 let kv_head = qh / heads_per_kv;
162164
163- let q_head: Vec < f64 > = ( 0 ..d_head)
164- . map ( |i| q_roped[ qh * d_head + i] )
165- . collect ( ) ;
165+ let q_head: Vec < f64 > = ( 0 ..d_head) . map ( |i| q_roped[ qh * d_head + i] ) . collect ( ) ;
166166
167167 // Attention scores: q · k_t / sqrt(d)
168168 let scores: Vec < f64 > = ( 0 ..seq_len)
@@ -192,9 +192,7 @@ pub fn replay_attention_roped(
192192
193193 // Requantize: a_i8 = round(a_f64 / scale_a)
194194 for i in 0 ..d_head {
195- a[ qh * d_head + i] = ( head_out[ i] * inv_scale)
196- . round ( )
197- . clamp ( -128.0 , 127.0 ) as i8 ;
195+ a[ qh * d_head + i] = ( head_out[ i] * inv_scale) . round ( ) . clamp ( -128.0 , 127.0 ) as i8 ;
198196 }
199197 }
200198
@@ -217,17 +215,19 @@ pub fn replay_attention_roped_raw(
217215 let heads_per_kv = cfg. n_q_heads / cfg. n_kv_heads ;
218216 let inv_sqrt_d = 1.0 / ( d_head as f64 ) . sqrt ( ) ;
219217 let seq_len = kv_cache_k_roped. len ( ) ;
220- let inv_scale = if scale_a. abs ( ) > 1e-30 { 1.0 / scale_a } else { 1.0 } ;
218+ let inv_scale = if scale_a. abs ( ) > 1e-30 {
219+ 1.0 / scale_a
220+ } else {
221+ 1.0
222+ } ;
221223
222224 let mut a_i8 = vec ! [ 0i8 ; cfg. hidden_dim] ;
223225 let mut a_f64 = vec ! [ 0.0f64 ; cfg. hidden_dim] ;
224226
225227 for qh in 0 ..cfg. n_q_heads {
226228 let kv_head = qh / heads_per_kv;
227229
228- let q_head: Vec < f64 > = ( 0 ..d_head)
229- . map ( |i| q_roped[ qh * d_head + i] )
230- . collect ( ) ;
230+ let q_head: Vec < f64 > = ( 0 ..d_head) . map ( |i| q_roped[ qh * d_head + i] ) . collect ( ) ;
231231
232232 let scores: Vec < f64 > = ( 0 ..seq_len)
233233 . map ( |t| {
@@ -255,9 +255,7 @@ pub fn replay_attention_roped_raw(
255255 for i in 0 ..d_head {
256256 let idx = qh * d_head + i;
257257 a_f64[ idx] = head_out[ i] ;
258- a_i8[ idx] = ( head_out[ i] * inv_scale)
259- . round ( )
260- . clamp ( -128.0 , 127.0 ) as i8 ;
258+ a_i8[ idx] = ( head_out[ i] * inv_scale) . round ( ) . clamp ( -128.0 , 127.0 ) as i8 ;
261259 }
262260 }
263261
@@ -337,8 +335,16 @@ pub fn measure_attention_diff(
337335 } ;
338336
339337 let n_f = n as f64 ;
340- let frac_eq = if n > 0 { histogram[ 0 ] as f64 / n_f } else { 0.0 } ;
341- let frac_le_1 = if n > 0 { ( histogram[ 0 ] + histogram[ 1 ] ) as f64 / n_f } else { 0.0 } ;
338+ let frac_eq = if n > 0 {
339+ histogram[ 0 ] as f64 / n_f
340+ } else {
341+ 0.0
342+ } ;
343+ let frac_le_1 = if n > 0 {
344+ ( histogram[ 0 ] + histogram[ 1 ] ) as f64 / n_f
345+ } else {
346+ 0.0
347+ } ;
342348 let frac_le_2 = if n > 0 {
343349 ( histogram[ 0 ] + histogram[ 1 ] + histogram[ 2 ] ) as f64 / n_f
344350 } else {
@@ -365,7 +371,11 @@ pub fn measure_attention_diff(
365371/// Returns `None` if the vectors have equal length and the L-infinity difference
366372/// is within `tolerance.max_abs_diff`. Returns `Some(i16::MAX)` if lengths differ
367373/// (malformed input), or `Some(max_diff)` if the tolerance is exceeded.
368- pub fn compare_attention_output ( claimed : & [ i8 ] , replayed : & [ i8 ] , tolerance : & AttentionToleranceConfig ) -> Option < i16 > {
374+ pub fn compare_attention_output (
375+ claimed : & [ i8 ] ,
376+ replayed : & [ i8 ] ,
377+ tolerance : & AttentionToleranceConfig ,
378+ ) -> Option < i16 > {
369379 if claimed. len ( ) != replayed. len ( ) {
370380 return Some ( i16:: MAX ) ;
371381 }
@@ -480,11 +490,17 @@ mod tests {
480490 let claimed = vec ! [ 1i8 , 2 , 3 ] ;
481491 let replayed = vec ! [ 1i8 , 2 , 3 , 4 ] ;
482492 let tol = AttentionToleranceConfig { max_abs_diff : 0 } ;
483- assert_eq ! ( compare_attention_output( & claimed, & replayed, & tol) , Some ( i16 :: MAX ) ) ;
493+ assert_eq ! (
494+ compare_attention_output( & claimed, & replayed, & tol) ,
495+ Some ( i16 :: MAX )
496+ ) ;
484497
485498 // Extended claimed vector is also rejected
486499 let claimed2 = vec ! [ 1i8 , 2 , 3 , 4 , 5 ] ;
487- assert_eq ! ( compare_attention_output( & claimed2, & replayed, & tol) , Some ( i16 :: MAX ) ) ;
500+ assert_eq ! (
501+ compare_attention_output( & claimed2, & replayed, & tol) ,
502+ Some ( i16 :: MAX )
503+ ) ;
488504 }
489505
490506 #[ test]
@@ -508,7 +524,7 @@ mod tests {
508524 #[ test]
509525 fn test_measure_diff_known_values ( ) {
510526 // diffs: 0, 1, 2, 3, 5, 7
511- let claimed = vec ! [ 10i8 , 20 , 30 , 40 , 50 , 60 ] ;
527+ let claimed = vec ! [ 10i8 , 20 , 30 , 40 , 50 , 60 ] ;
512528 let replayed = vec ! [ 10i8 , 19 , 28 , 37 , 45 , 53 ] ;
513529 let stats = measure_attention_diff ( & claimed, & replayed, 3 , 10 ) . unwrap ( ) ;
514530 assert_eq ! ( stats. linf, 7 ) ;
0 commit comments