@@ -203,11 +203,12 @@ def _mhc_projection_bwd_fused(
203203 phi = tl .load (
204204 phi_ptrs , mask = (offs_n_full [:, None ] < N ) & mask_k [None , :], other = 0.0
205205 ) # (BLOCK_SIZE_N, BLOCK_SIZE_K)
206+ grad_ms = tl .load (grad_ms_ptrs , mask = offs_r < M , other = 0.0 , cache_modifier = ".ca" ) # (BLOCK_SIZE_M,)
207+
208+ grad_x = x * (grad_ms * 2 / tl .cast (K , tl .float32 ))[:, None ]
206209 grad_x = tl .dot (
207- grad_h , phi , input_precision = precision , out_dtype = tl .float32
210+ grad_h , phi , acc = grad_x , input_precision = precision , out_dtype = tl .float32
208211 ) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
209- grad_ms = tl .load (grad_ms_ptrs , mask = offs_r < M , other = 0.0 , cache_modifier = ".ca" ) # (BLOCK_SIZE_M,)
210- grad_x += x * (grad_ms * 2 / tl .cast (K , tl .float32 ))[:, None ]
211212 grad_x_ptrs = grad_x_ptr + offs_m [:, None ] * stride_grad_xm + offs_k [None , :] * stride_grad_xk
212213 grad_x = grad_x .to (x .dtype )
213214 tl .store (grad_x_ptrs , grad_x , mask = mask_m [:, None ] & mask_k [None , :])
@@ -1179,7 +1180,8 @@ def _mhc_expand_combine_fwd(
11791180 # Residual connection path: res_out = f @ H_post:
11801181 # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C)
11811182 # Due to broadcasting, it's equivalent to a multiplicaiton
1182- res_out = f [:, :, None ] * H_post [:, None , :] # (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
1183+ out_acc = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_C , n ), dtype = tl .float32 )
1184+ out_acc = tl .fma (f [:, :, None ], H_post [:, None , :], out_acc )
11831185
11841186 H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl .arange (0 , BLOCK_SIZE_M * n * n )
11851187 H_res = tl .load (H_res_ptr + H_res_offs , mask = H_res_offs < M * n * n , other = 0.0 , cache_modifier = ".ca" )
@@ -1199,7 +1201,6 @@ def _mhc_expand_combine_fwd(
11991201 # + x[:, :, 1] @ H_res[:, 1, :]
12001202 # + x[:, :, 2] @ H_res[:, 2, :]
12011203 # + x[:, :, 3] @ H_res[:, 3, :]
1202- manifold_out_acc = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_C , n ), dtype = tl .float32 )
12031204
12041205 x_reshape = tl .reshape (x , (BLOCK_SIZE_M , BLOCK_SIZE_C , 2 , 2 ))
12051206 x01 , x23 = tl .split (x_reshape ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2)
@@ -1211,14 +1212,12 @@ def _mhc_expand_combine_fwd(
12111212 H_res0 , H_res1 = tl .split (H_res01 ) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
12121213 H_res2 , H_res3 = tl .split (H_res23 ) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
12131214
1214- manifold_out_acc = tl .fma (x0 [:, :, None ], H_res0 [:, None , :], manifold_out_acc )
1215- manifold_out_acc = tl .fma (x1 [:, :, None ], H_res1 [:, None , :], manifold_out_acc )
1216- manifold_out_acc = tl .fma (x2 [:, :, None ], H_res2 [:, None , :], manifold_out_acc )
1217- manifold_out_acc = tl .fma (x3 [:, :, None ], H_res3 [:, None , :], manifold_out_acc )
1215+ out_acc = tl .fma (x0 [:, :, None ], H_res0 [:, None , :], out_acc )
1216+ out_acc = tl .fma (x1 [:, :, None ], H_res1 [:, None , :], out_acc )
1217+ out_acc = tl .fma (x2 [:, :, None ], H_res2 [:, None , :], out_acc )
1218+ out_acc = tl .fma (x3 [:, :, None ], H_res3 [:, None , :], out_acc )
12181219
1219- manifold_out = manifold_out_acc .to (x .dtype )
1220-
1221- out = manifold_out + res_out
1220+ out = out_acc .to (x .dtype )
12221221 out = tl .reshape (out , (BLOCK_SIZE_M , BLOCK_SIZE_C * n )) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n)
12231222
12241223 output_ptrs = (
@@ -1486,7 +1485,6 @@ def _mhc_expand_combine_with_bias_fwd(
14861485 f_ptrs = f_ptr + offs_m [:, None ] * stride_fm + offs_c [None , :] * stride_fc
14871486 f = tl .load (f_ptrs , mask = mask_m [:, None ] & mask_c [None , :], other = 0.0 )
14881487 bias = tl .load (bias_ptr + offs_c * stride_bias , mask = mask_c , other = 0.0 ) # (BLOCK_SIZE_C,)
1489- f = f + bias [None , :] # (BLOCK_SIZE_M, BLOCK_SIZE_C)
14901488
14911489 offs_H_post = pid_m * BLOCK_SIZE_M * n + tl .arange (0 , BLOCK_SIZE_M * n )
14921490 H_post = tl .load (H_post_ptr + offs_H_post , mask = offs_H_post < M * n , other = 0.0 , cache_modifier = ".ca" )
@@ -1495,7 +1493,9 @@ def _mhc_expand_combine_with_bias_fwd(
14951493 # Residual connection path: res_out = f @ H_post + bias @ H_post:
14961494 # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C)
14971495 # Due to broadcasting, it's equivalent to a multiplicaiton
1498- res_out = f [:, :, None ] * H_post [:, None , :] # (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
1496+ out_acc = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_C , n ), dtype = tl .float32 )
1497+ out_acc = tl .fma (bias [None , :, None ], H_post [:, None , :], out_acc )
1498+ out_acc = tl .fma (f [:, :, None ], H_post [:, None , :], out_acc )
14991499
15001500 H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl .arange (0 , BLOCK_SIZE_M * n * n )
15011501 H_res = tl .load (H_res_ptr + H_res_offs , mask = H_res_offs < M * n * n , other = 0.0 , cache_modifier = ".ca" )
@@ -1515,7 +1515,6 @@ def _mhc_expand_combine_with_bias_fwd(
15151515 # + x[:, :, 1] @ H_res[:, 1, :]
15161516 # + x[:, :, 2] @ H_res[:, 2, :]
15171517 # + x[:, :, 3] @ H_res[:, 3, :]
1518- manifold_out_acc = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_C , n ), dtype = tl .float32 )
15191518
15201519 x_reshape = tl .reshape (x , (BLOCK_SIZE_M , BLOCK_SIZE_C , 2 , 2 ))
15211520 x01 , x23 = tl .split (x_reshape ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2)
@@ -1527,14 +1526,12 @@ def _mhc_expand_combine_with_bias_fwd(
15271526 H_res0 , H_res1 = tl .split (H_res01 ) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
15281527 H_res2 , H_res3 = tl .split (H_res23 ) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
15291528
1530- manifold_out_acc = tl .fma (x0 [:, :, None ], H_res0 [:, None , :], manifold_out_acc )
1531- manifold_out_acc = tl .fma (x1 [:, :, None ], H_res1 [:, None , :], manifold_out_acc )
1532- manifold_out_acc = tl .fma (x2 [:, :, None ], H_res2 [:, None , :], manifold_out_acc )
1533- manifold_out_acc = tl .fma (x3 [:, :, None ], H_res3 [:, None , :], manifold_out_acc )
1534-
1535- manifold_out = manifold_out_acc .to (x .dtype )
1529+ out_acc = tl .fma (x0 [:, :, None ], H_res0 [:, None , :], out_acc )
1530+ out_acc = tl .fma (x1 [:, :, None ], H_res1 [:, None , :], out_acc )
1531+ out_acc = tl .fma (x2 [:, :, None ], H_res2 [:, None , :], out_acc )
1532+ out_acc = tl .fma (x3 [:, :, None ], H_res3 [:, None , :], out_acc )
15361533
1537- out = manifold_out + res_out
1534+ out = out_acc . to ( x . dtype )
15381535 out = tl .reshape (out , (BLOCK_SIZE_M , BLOCK_SIZE_C * n )) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n)
15391536
15401537 output_ptrs = (
@@ -1636,7 +1633,6 @@ def _mhc_expand_combine_with_bias_bwd(
16361633 f = tl .load (f_ptrs , mask = mask_m [:, None ] & mask_c [None , :], other = 0.0 )
16371634
16381635 bias = tl .load (bias_ptr + offs_c * stride_bias , mask = mask_c , other = 0.0 ) # (BLOCK_SIZE_C,)
1639- f = f + bias [None , :] # (BLOCK_SIZE_M, BLOCK_SIZE_C)
16401636
16411637 H_post_offs = pid_m * BLOCK_SIZE_M * n + tl .arange (0 , BLOCK_SIZE_M * n )
16421638 H_post = tl .load (H_post_ptr + H_post_offs , mask = H_post_offs < M * n , other = 0.0 )
@@ -1665,6 +1661,13 @@ def _mhc_expand_combine_with_bias_bwd(
16651661 input_precision = precision ,
16661662 out_dtype = tl .float32 ,
16671663 ) # (BLOCK_SIZE_M, 1, n)
1664+ grad_H_post = tl .dot (
1665+ tl .broadcast_to (bias [None , None , :], (BLOCK_SIZE_M , 1 , BLOCK_SIZE_C )),
1666+ tl .reshape (grad_out , (BLOCK_SIZE_M , BLOCK_SIZE_C , n )),
1667+ acc = grad_H_post ,
1668+ input_precision = precision ,
1669+ out_dtype = tl .float32 ,
1670+ ) # (BLOCK_SIZE_M, 1, n)
16681671 grad_H_post = tl .reshape (grad_H_post , (BLOCK_SIZE_M * n ,)) # (BLOCK_SIZE_M * n)
16691672 offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl .arange (0 , BLOCK_SIZE_M * n )
16701673 grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post
0 commit comments