
Commit 4657371

jesus
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
1 parent d768217 commit 4657371

2 files changed: 49 additions & 42 deletions

tests/pytorch/test_mhc.py (23 additions & 19 deletions)
@@ -23,7 +23,6 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-@torch.compile
 def mHCProjectionRef(x, phi):
     """
     Reference operator for mHC's projection building operation.
@@ -48,7 +47,6 @@ def mHCProjectionRef(x, phi):
     return Hs.to(x_dtype), ms
 
 
-@torch.compile
 def mHCScaleRef(H, alpha, beta, ms, n):
     """
     Reference operator for mHC's pre and post calculations
@@ -101,7 +99,6 @@ def mHCScaleRef(H, alpha, beta, ms, n):
     return out.to(H_dtype)
 
 
-@torch.compile
 def mHCSinkhornRef(H_res, n=4, iterations=20):
     """
     Sinkhorn-Knopp algorithm to convert a matrix into a doubly stochastic matrix.
@@ -136,7 +133,6 @@ def mHCSinkhornRef(H_res, n=4, iterations=20):
     return H_res_out
 
 
-@torch.compile
 def mHCAggregateRef(x, H_pre, n):
     """
     Reference operator for applying mHC's pre matrix H to a vector x.
@@ -153,7 +149,6 @@ def mHCAggregateRef(x, H_pre, n):
 
     return out
 
-@torch.compile
 def mHCExpandCombineRef(f, bias, H_post, x, H_res, n):
     """
     Reference operator for applying mHC's post transformation and residual transformation
@@ -167,15 +162,25 @@ def mHCExpandCombineRef(f, bias, H_post, x, H_res, n):
 
     s, b, C, n = x.shape
 
+    # The Triton kernels use FMA and MMA instructions with an fp32 accumulator for
+    # bf16 test cases, which gives better numerical stability than this PyTorch
+    # implementation. Cast to fp32 here so the reference matches the kernels' results.
+    input_dtype = f.dtype
+    f = f.to(torch.float32)
+    bias = bias.to(torch.float32) if bias is not None else None
+    H_post = H_post.to(torch.float32)
+    x = x.to(torch.float32)
+    H_res = H_res.to(torch.float32)
+
     if bias is not None:
-        f = f + bias
+        f = f + bias[None, None, :]
 
     f = f.view(s, b, C, 1)
     H_post = H_post.view(s, b, 1, n)
 
     out = f @ H_post + x @ H_res  # (s, b, C, n)
 
-    return out
+    return out.to(input_dtype)
 
 @dataclass
 class MHCConfig:
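
A note on the hunk above: the reference now upcasts every operand to fp32 before the matmuls, because the Triton kernels accumulate their FMA/MMA results in fp32 even for bf16 inputs. A minimal standalone sketch of this pattern, with illustrative shapes and names (not taken verbatim from the test file):

```python
import torch

def expand_combine_ref_sketch(f, bias, H_post, x, H_res):
    # Upcast to fp32 so the reference matches a kernel that keeps its
    # accumulator in fp32, then cast the result back to the input dtype.
    s, b, C, n = x.shape
    input_dtype = f.dtype
    f, H_post, x, H_res = (t.to(torch.float32) for t in (f, H_post, x, H_res))
    if bias is not None:
        f = f + bias.to(torch.float32)[None, None, :]  # broadcast over (s, b)
    out = f.view(s, b, C, 1) @ H_post.view(s, b, 1, n) + x @ H_res  # (s, b, C, n)
    return out.to(input_dtype)

s, b, C, n = 2, 3, 8, 4
out = expand_combine_ref_sketch(
    torch.randn(s, b, C, dtype=torch.bfloat16),
    torch.randn(C, dtype=torch.bfloat16),
    torch.randn(s, b, n, dtype=torch.bfloat16),
    torch.randn(s, b, C, n, dtype=torch.bfloat16),
    torch.randn(s, b, n, n, dtype=torch.bfloat16),
)
print(out.shape, out.dtype)  # torch.Size([2, 3, 8, 4]) torch.bfloat16
```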
@@ -232,27 +237,27 @@ def desc(cfg):
     MHCConfig(
         8,
         128,
-        16 * 192,
+        5129,
     ),
     MHCConfig(
         8,
-        1,
-        16 * 500,
+        512,
+        8000,
     ),
     MHCConfig(
-        8,
-        128,
-        16 * 512,
+        4,
+        1024,
+        8192,
     ),
     MHCConfig(
-        8,
-        1,
-        16 * 376,
+        2,
+        4096,
+        8192,
     ),
     MHCConfig(
         8,
         128,
-        16 * 1024,
+        16384,
     ),
 ]
 
@@ -449,8 +454,7 @@ def test_mhc_expand_combine(cfg: MHCConfig, dtype, with_bias):
     f = torch.randn(s, b, C, device="cuda", requires_grad=True, dtype=dtype)
     bias = None
     if with_bias:
-        bias_raw = torch.randn(C, device="cuda", requires_grad=True, dtype=dtype) * 0.1
-        bias = bias_raw.detach().clone().requires_grad_(True)
+        bias = torch.randn(C, device="cuda", requires_grad=True, dtype=dtype)
     H_post = torch.randn(s, b, n, device="cuda", requires_grad=True, dtype=dtype)
     x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype)
     H_res = torch.randn(s, b, n, n, device="cuda", requires_grad=True, dtype=dtype)

transformer_engine/common/triton/mhc.py (26 additions & 23 deletions)
@@ -203,11 +203,12 @@ def _mhc_projection_bwd_fused(
     phi = tl.load(
         phi_ptrs, mask=(offs_n_full[:, None] < N) & mask_k[None, :], other=0.0
     )  # (BLOCK_SIZE_N, BLOCK_SIZE_K)
+    grad_ms = tl.load(grad_ms_ptrs, mask=offs_r < M, other=0.0, cache_modifier=".ca")  # (BLOCK_SIZE_M,)
+
+    grad_x = x * (grad_ms * 2 / tl.cast(K, tl.float32))[:, None]
     grad_x = tl.dot(
-        grad_h, phi, input_precision=precision, out_dtype=tl.float32
+        grad_h, phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32
     )  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
-    grad_ms = tl.load(grad_ms_ptrs, mask=offs_r < M, other=0.0, cache_modifier=".ca")  # (BLOCK_SIZE_M,)
-    grad_x += x * (grad_ms * 2 / tl.cast(K, tl.float32))[:, None]
     grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_k[None, :] * stride_grad_xk
     grad_x = grad_x.to(x.dtype)
     tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_k[None, :])
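
The reordering above folds the elementwise `grad_ms` term into the `tl.dot` call through its accumulator argument, rather than adding it after the matmul: the accumulator is seeded with `x * (2 * grad_ms / K)` and the product `grad_h @ phi` is accumulated on top. Both orders compute the same quantity; a small PyTorch check with assumed sizes (these are not the kernel's block sizes):

```python
import torch

M, N, K = 64, 32, 128
grad_h, phi = torch.randn(M, N), torch.randn(N, K)
x, grad_ms = torch.randn(M, K), torch.randn(M)

# Old order: matmul first, then add the elementwise term.
old = grad_h @ phi + x * (grad_ms * 2 / K)[:, None]

# New order (what acc= in tl.dot expresses): seed the accumulator with the
# elementwise term, then accumulate the matmul into it.
acc = x * (grad_ms * 2 / K)[:, None]
new = torch.addmm(acc, grad_h, phi)

print(torch.allclose(old, new, atol=1e-5))  # True
```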
@@ -1179,7 +1180,8 @@ def _mhc_expand_combine_fwd(
     # Residual connection path: res_out = f @ H_post:
     # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
     # Due to broadcasting, it's equivalent to a multiplication
-    res_out = f[:, :, None ] * H_post[:, None, :]  # (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
+    out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32)
+    out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc)
 
     H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n)
     H_res = tl.load(H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0, cache_modifier=".ca")
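
As the comment in the hunk says, a batched `(BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n)` product is just a broadcast multiply, which the new code expresses as a single `tl.fma` into a zero-initialized fp32 accumulator. A quick PyTorch check of that equivalence, with illustrative sizes:

```python
import torch

M, C, n = 16, 8, 4
f, H_post = torch.randn(M, C), torch.randn(M, n)

# Batched matmul: (M, C, 1) @ (M, 1, n) -> (M, C, n).
via_matmul = f[:, :, None] @ H_post[:, None, :]

# Broadcast multiply, as the kernel does: same result, no MMA needed.
via_broadcast = f[:, :, None] * H_post[:, None, :]

print(torch.allclose(via_matmul, via_broadcast, atol=1e-6))  # True
```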
@@ -1199,7 +1201,6 @@ def _mhc_expand_combine_fwd(
     # + x[:, :, 1] @ H_res[:, 1, :]
     # + x[:, :, 2] @ H_res[:, 2, :]
     # + x[:, :, 3] @ H_res[:, 3, :]
-    manifold_out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32)
 
     x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2))
     x01, x23 = tl.split(x_reshape)  # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2)
@@ -1211,14 +1212,12 @@ def _mhc_expand_combine_fwd(
     H_res0, H_res1 = tl.split(H_res01)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
     H_res2, H_res3 = tl.split(H_res23)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
 
-    manifold_out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], manifold_out_acc)
-    manifold_out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], manifold_out_acc)
-    manifold_out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], manifold_out_acc)
-    manifold_out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], manifold_out_acc)
+    out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc)
+    out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc)
+    out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc)
+    out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc)
 
-    manifold_out = manifold_out_acc.to(x.dtype)
-
-    out = manifold_out + res_out
+    out = out_acc.to(x.dtype)
     out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n))  # (BLOCK_SIZE_M, BLOCK_SIZE_C*n)
 
     output_ptrs = (
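
The `tl.fma` chain above unrolls the small batched product `x @ H_res` over the n = 4 columns of the residual matrix, so the whole output stays in one fp32 accumulator (`out_acc`) instead of being split into `manifold_out` and `res_out` and summed in the input dtype. A sketch of the rank-1 decomposition it relies on, with illustrative sizes:

```python
import torch

M, C, n = 16, 8, 4
x, H_res = torch.randn(M, C, n), torch.randn(M, n, n)

# Direct batched matmul over the last two dims.
direct = x @ H_res  # (M, C, n)

# Unrolled as n broadcast FMAs, matching the kernel's tl.fma chain.
acc = torch.zeros(M, C, n)
for i in range(n):
    acc = acc + x[:, :, i, None] * H_res[:, None, i, :]

print(torch.allclose(direct, acc, atol=1e-5))  # True
```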
@@ -1486,7 +1485,6 @@ def _mhc_expand_combine_with_bias_fwd(
     f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc
     f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0)
     bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0)  # (BLOCK_SIZE_C,)
-    f = f + bias[None, :]  # (BLOCK_SIZE_M, BLOCK_SIZE_C)
 
     offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
     H_post = tl.load(H_post_ptr + offs_H_post, mask=offs_H_post < M * n, other=0.0, cache_modifier=".ca")
@@ -1495,7 +1493,9 @@ def _mhc_expand_combine_with_bias_fwd(
     # Residual connection path: res_out = f @ H_post + bias @ H_post:
     # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
     # Due to broadcasting, it's equivalent to a multiplication
-    res_out = f[:, :, None] * H_post[:, None, :]  # (BLOCK_SIZE_M, BLOCK_SIZE_C, n)
+    out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32)
+    out_acc = tl.fma(bias[None, :, None], H_post[:, None, :], out_acc)
+    out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc)
 
     H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n)
     H_res = tl.load(H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0, cache_modifier=".ca")
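
Instead of pre-adding the bias into `f` in the input dtype (the line removed in the first hunk above), the bias is distributed over the product and gets its own `tl.fma` into the fp32 accumulator: (f + bias) * H_post = f * H_post + bias * H_post. That way the sum `f + bias` never rounds through bf16. A PyTorch sketch of the two forms, with illustrative sizes:

```python
import torch

M, C, n = 16, 8, 4
f = torch.randn(M, C, dtype=torch.bfloat16)
bias = torch.randn(C, dtype=torch.bfloat16)
H_post = torch.randn(M, n, dtype=torch.bfloat16)

# Old: add bias into f in bf16, then broadcast-multiply in bf16.
old = ((f + bias[None, :])[:, :, None] * H_post[:, None, :]).float()

# New: distribute the bias over the product; both terms accumulate in fp32.
acc = torch.zeros(M, C, n)
acc = acc + bias.float()[None, :, None] * H_post.float()[:, None, :]
acc = acc + f.float()[:, :, None] * H_post.float()[:, None, :]

# Equal up to bf16 rounding; the fp32 path loses less precision.
print((old - acc).abs().max())
```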
@@ -1515,7 +1515,6 @@ def _mhc_expand_combine_with_bias_fwd(
     # + x[:, :, 1] @ H_res[:, 1, :]
     # + x[:, :, 2] @ H_res[:, 2, :]
     # + x[:, :, 3] @ H_res[:, 3, :]
-    manifold_out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32)
 
     x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2))
     x01, x23 = tl.split(x_reshape)  # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2)
@@ -1527,14 +1526,12 @@ def _mhc_expand_combine_with_bias_fwd(
     H_res0, H_res1 = tl.split(H_res01)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
     H_res2, H_res3 = tl.split(H_res23)  # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n)
 
-    manifold_out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], manifold_out_acc)
-    manifold_out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], manifold_out_acc)
-    manifold_out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], manifold_out_acc)
-    manifold_out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], manifold_out_acc)
-
-    manifold_out = manifold_out_acc.to(x.dtype)
+    out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc)
+    out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc)
+    out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc)
+    out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc)
 
-    out = manifold_out + res_out
+    out = out_acc.to(x.dtype)
     out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n))  # (BLOCK_SIZE_M, BLOCK_SIZE_C*n)
 
     output_ptrs = (
@@ -1636,7 +1633,6 @@ def _mhc_expand_combine_with_bias_bwd(
     f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0)
 
     bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0)  # (BLOCK_SIZE_C,)
-    f = f + bias[None, :]  # (BLOCK_SIZE_M, BLOCK_SIZE_C)
 
     H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
     H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0)
@@ -1665,6 +1661,13 @@ def _mhc_expand_combine_with_bias_bwd(
         input_precision=precision,
         out_dtype=tl.float32,
     )  # (BLOCK_SIZE_M, 1, n)
+    grad_H_post = tl.dot(
+        tl.broadcast_to(bias[None, None, :], (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)),
+        tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)),
+        acc=grad_H_post,
+        input_precision=precision,
+        out_dtype=tl.float32,
+    )  # (BLOCK_SIZE_M, 1, n)
     grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,))  # (BLOCK_SIZE_M * n)
     offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n)
     grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post
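
Since the bias is no longer folded into `f` in the forward pass, the backward pass must account for the bias path explicitly: the gradient of `(f + bias) * H_post` with respect to `H_post` splits into an `f` term and a `bias` term, and the second `tl.dot` above accumulates the bias term into the same `grad_H_post` buffer. A PyTorch check of that split, with illustrative sizes:

```python
import torch

M, C, n = 16, 8, 4
f, bias = torch.randn(M, C), torch.randn(C)
grad_out = torch.randn(M, C, n)

# Reference: d/dH_post of out = (f + bias) * H_post, summed over channels C.
ref = ((f + bias[None, :])[:, :, None] * grad_out).sum(dim=1)  # (M, n)

# Split form used by the kernel: an f term plus a separate bias term
# (the extra tl.dot in the hunk), accumulated into the same buffer.
grad_f_term = (f[:, :, None] * grad_out).sum(dim=1)
grad_bias_term = (bias[None, :, None] * grad_out).sum(dim=1)

print(torch.allclose(ref, grad_f_term + grad_bias_term, atol=1e-5))  # True
```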
