
Commit dd8f247

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent fc6b65f commit dd8f247

8 files changed

Lines changed: 155 additions & 127 deletions


SOL.py

Lines changed: 10 additions & 8 deletions
@@ -18,8 +18,8 @@
 hardware_specs = {
     "B200": {
         "MEM_BANDWIDTH": 8.0e12, # 8 TB/s
-        "FP32_FLOPS": 75.0e12, # 75 TFLOPS (Vector FP32 for tiny inner dims)
-        "TF32_FLOPS": 1.125e15, # 1,125 TFLOPS (Dense TF32 Tensor Core)
+        "FP32_FLOPS": 75.0e12, # 75 TFLOPS (Vector FP32 for tiny inner dims)
+        "TF32_FLOPS": 1.125e15, # 1,125 TFLOPS (Dense TF32 Tensor Core)
     }
 }

@@ -30,9 +30,10 @@
 print(f" Format: {dtype} ({bytes_per_elem} bytes/elem)")
 print(f"================================================================================\n")

+
 def print_sol_breakdown(name, mem_gb, flops_g=None, use_tf32=False):
     bw_gb_s = hardware_specs[GPU]["MEM_BANDWIDTH"] / 1e9
-
+
     # Select peak FLOPs based on TensorCore utilization
     if use_tf32:
         peak_flops_g = hardware_specs[GPU]["TF32_FLOPS"] / 1e9
@@ -42,13 +43,13 @@ def print_sol_breakdown(name, mem_gb, flops_g=None, use_tf32=False):
         math_type = "FP32 Vector"

     time_mem_ms = (mem_gb / bw_gb_s) * 1000
-
+
     print(f"[{name}]")
     if flops_g is not None:
         time_math_ms = (flops_g / peak_flops_g) * 1000
         sol_time = max(time_mem_ms, time_math_ms)
         bound = "FLOPS bounded" if time_math_ms > time_mem_ms else "Memory bounded"
-
+
         print(f" ├─ Architecture : {math_type}")
         print(f" ├─ Total Mem R/W: {mem_gb:8.4f} GB")
         print(f" ├─ Total Math : {flops_g:8.4f} GFLOPS")
@@ -62,6 +63,7 @@ def print_sol_breakdown(name, mem_gb, flops_g=None, use_tf32=False):
         print(f" ├─ Mem Time : {time_mem_ms:8.4f} ms")
         print(f" └─ SOL Time : {sol_time:8.4f} ms ({bound})\n")

+
 # ---------------------------------------------------------
 # 1. Projection kernel: (B, T, n*C) @ (n*C, 32)
 # ---------------------------------------------------------
@@ -133,14 +135,14 @@ def print_sol_breakdown(name, mem_gb, flops_g=None, use_tf32=False):
 post_in1_2_gb = B * T * 1 * C * bytes_per_elem / 1e9
 post_in2_1_gb = B * T * n * n * bytes_per_elem / 1e9
 post_in2_2_gb = B * T * n * C * bytes_per_elem / 1e9
-post_out_gb = B * T * n * C * bytes_per_elem / 1e9
+post_out_gb = B * T * n * C * bytes_per_elem / 1e9

 post_mem_gb = post_in1_1_gb + post_in1_2_gb + post_in2_1_gb + post_in2_2_gb + post_out_gb

 flops_term1_g = B * T * (2 * n * 1 * C) / 1e9
 flops_term2_g = B * T * (2 * n * n * C) / 1e9
-flops_add_g = B * T * n * C / 1e9
-post_flops_g = flops_term1_g + flops_term2_g + flops_add_g
+flops_add_g = B * T * n * C / 1e9
+post_flops_g = flops_term1_g + flops_term2_g + flops_add_g

 print(f"================================================================================")
 print(f"5. Post + Res Kernel (Fused): (B, T, n, 1) @ (B, T, 1, C) + (B, T, n, n) @ (B, T, n, C)")

cutile_kernels.py

Lines changed: 1 addition & 1 deletion
@@ -961,4 +961,4 @@ def fused_proj_rms(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tuple[Tensor
     proj: [M, N] = x @ weight^T
     r: [M, 1] = 1 / (||x|| / sqrt(K) + eps)
     """
-    return FusedProjRms.apply(x, weight, eps)
+    return FusedProjRms.apply(x, weight, eps)
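
The docstring in this hunk spells out the math the fused kernel computes. A hedged, unfused PyTorch equivalent of that formula could look like the sketch below (proj_rms_reference is a made-up name; the real implementation fuses both outputs into one pass through FusedProjRms):

import math
import torch


def proj_rms_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # proj: [M, N] = x @ weight^T, as in the fused_proj_rms docstring.
    proj = x @ weight.t()
    # r: [M, 1] = 1 / (||x|| / sqrt(K) + eps), where K is the reduction dim.
    K = x.shape[-1]
    r = 1.0 / (x.norm(dim=-1, keepdim=True) / math.sqrt(K) + eps)
    return proj, r


x = torch.randn(4, 64)
weight = torch.randn(32, 64)
proj, r = proj_rms_reference(x, weight)
print(proj.shape, r.shape)  # torch.Size([4, 32]) torch.Size([4, 1])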

mhc_bench.py

Lines changed: 22 additions & 2 deletions
@@ -26,9 +26,10 @@
     mHCSinkhornRef,
     mHCElementwiseRef,
     mHCPreRef,
-    mHCPostResRef
+    mHCPostResRef,
 )

+
 def run_sinkhorn_triton(B, T, n, dtype, device, iters, do_backward):
     nvtx.range_push("mhc_sinkhorn_triton_fwd")
     x = torch.randn((B, T, n, n), device=device, dtype=dtype, requires_grad=do_backward)
@@ -39,6 +40,7 @@ def run_sinkhorn_triton(B, T, n, dtype, device, iters, do_backward):
         y.sum().backward()
         nvtx.range_pop()

+
 def run_sinkhorn_cutile(B, T, n, dtype, device, iters, do_backward):
     nvtx.range_push("mhc_sinkhorn_cutile_fwd")
     x = torch.randn((B, T, n, n), device=device, dtype=dtype, requires_grad=do_backward)
@@ -49,6 +51,7 @@ def run_sinkhorn_cutile(B, T, n, dtype, device, iters, do_backward):
         y.sum().backward()
         nvtx.range_pop()

+
 def run_sinkhorn_compile(B, T, n, dtype, device, iters, do_backward):
     nvtx.range_push("mhc_sinkhorn_compile_fwd")
     x = torch.randn((B, T, n, n), device=device, dtype=dtype, requires_grad=do_backward)
@@ -59,6 +62,7 @@ def run_sinkhorn_compile(B, T, n, dtype, device, iters, do_backward):
         y.sum().backward()
         nvtx.range_pop()

+
 def run_sinkhorn(B, T, n, dtype, device, iters, do_backward):
     run_sinkhorn_cutile(B, T, n, dtype, device, iters, do_backward)
     run_sinkhorn_triton(B, T, n, dtype, device, iters, do_backward)
@@ -78,6 +82,7 @@ def run_projection_triton(B, T, n, C, dtype, device, do_backward):
         (Hs.sum() + r.sum()).backward()
         nvtx.range_pop()

+
 def run_projection_cutile(B, T, n, C, dtype, device, do_backward):
     nC = n * C
     N = 2 * n + n * n
@@ -91,6 +96,7 @@
         (Hs.sum() + r.sum()).backward()
         nvtx.range_pop()

+
 def run_projection_compile(B, T, n, C, dtype, device, do_backward):
     nC = n * C
     N = 2 * n + n * n
@@ -104,6 +110,7 @@
         (Hs.sum() + r.sum()).backward()
         nvtx.range_pop()

+
 def run_projection(B, T, n, C, dtype, device, do_backward):
     run_projection_cutile(B, T, n, C, dtype, device, do_backward)
     run_projection_triton(B, T, n, C, dtype, device, do_backward)
@@ -124,6 +131,7 @@ def run_elementwise_triton(B, T, n, dtype, device, do_backward):
         out.sum().backward()
         nvtx.range_pop()

+
 def run_elementwise_compile(B, T, n, dtype, device, do_backward):
     N = 2 * n + n * n
     nvtx.range_push("mhc_elementwise_compile_fwd")
@@ -138,6 +146,7 @@
         out.sum().backward()
         nvtx.range_pop()

+
 def run_elementwise(B, T, n, dtype, device, do_backward):
     run_elementwise_triton(B, T, n, dtype, device, do_backward)
     run_elementwise_compile(B, T, n, dtype, device, do_backward)
@@ -154,6 +163,7 @@ def run_pre_triton(B, T, n, C, dtype, device, do_backward):
         out.sum().backward()
         nvtx.range_pop()

+
 def run_pre_cutile(B, T, n, C, dtype, device, do_backward):
     nvtx.range_push("mhc_pre_cutile_fwd")
     x = torch.randn(B, T, n, C, dtype=dtype, requires_grad=True, device=device)
@@ -165,6 +175,7 @@
         out.sum().backward()
         nvtx.range_pop()

+
 def run_pre_compile(B, T, n, C, dtype, device, do_backward):
     nvtx.range_push("mhc_pre_compile_fwd")
     x = torch.randn(B, T, n, C, dtype=dtype, requires_grad=True, device=device)
@@ -176,6 +187,7 @@
         out.sum().backward()
         nvtx.range_pop()

+
 def run_pre(B, T, n, C, dtype, device, do_backward):
     run_pre_cutile(B, T, n, C, dtype, device, do_backward)
     run_pre_triton(B, T, n, C, dtype, device, do_backward)
@@ -195,6 +207,7 @@ def run_post_res_triton(B, T, n, C, dtype, device, do_backward):
         out.sum().backward()
         nvtx.range_pop()

+
 def run_post_res_cutile(B, T, n, C, dtype, device, do_backward):
     nvtx.range_push("mhc_post_res_cutile_fwd")
     x = torch.randn(B, T, n, C, dtype=dtype, requires_grad=True, device=device)
@@ -208,6 +221,7 @@
         out.sum().backward()
         nvtx.range_pop()

+
 def run_post_res_compile(B, T, n, C, dtype, device, do_backward):
     nvtx.range_push("mhc_post_res_compile_fwd")
     x = torch.randn(B, T, n, C, dtype=dtype, requires_grad=True, device=device)
@@ -221,14 +235,20 @@
         out.sum().backward()
         nvtx.range_pop()

+
 def run_post_res(B, T, n, C, dtype, device, do_backward):
     run_post_res_cutile(B, T, n, C, dtype, device, do_backward)
     run_post_res_triton(B, T, n, C, dtype, device, do_backward)
     run_post_res_compile(B, T, n, C, dtype, device, do_backward)

+
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--operation", choices=["sinkhorn", "projection", "elementwise", "pre", "post_res", "all"], required=True)
+    parser.add_argument(
+        "--operation",
+        choices=["sinkhorn", "projection", "elementwise", "pre", "post_res", "all"],
+        required=True,
+    )
     parser.add_argument("--dtype", choices=["float32", "bfloat16"], default="float32")
     parser.add_argument("--warmup", type=int, default=2)
     parser.add_argument("--iters", type=int, default=1)

native_kernels.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F

+
 @torch.compile
 def mHCProjectionRef(x, phi):
     """
@@ -151,4 +152,3 @@ def mHCPostResRef(f, H_post, x, H_res, n):
     out = H_post @ f + H_res @ x # (B, T, n, C)

     return out
-
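
The context line here is the whole reference computation: a rank-1 expand term plus a residual mixing term, both as batched matmuls. A small hedged sketch using the shapes from the inline comment and the SOL.py section-5 header (the B, T, n, C values are arbitrary; the repository's reference additionally runs under @torch.compile):

import torch

# Shapes follow mHCPostResRef's comment and the SOL.py section-5 header:
# (B, T, n, 1) @ (B, T, 1, C) + (B, T, n, n) @ (B, T, n, C) -> (B, T, n, C).
B, T, n, C = 2, 16, 4, 32
H_post = torch.randn(B, T, n, 1)
f = torch.randn(B, T, 1, C)
H_res = torch.randn(B, T, n, n)
x = torch.randn(B, T, n, C)

out = H_post @ f + H_res @ x  # batched matmul over the (B, T) dims
print(out.shape)  # torch.Size([2, 16, 4, 32])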

tests/pytorch/test_mhc.py

Lines changed: 4 additions & 1 deletion
@@ -153,6 +153,7 @@ def mHCAggregateRef(x, H_pre, n):

     return out

+
 @torch.compile
 def mHCExpandCombineRef(f, bias, H_post, x, H_res, n):
     """
@@ -177,6 +178,7 @@ def mHCExpandCombineRef(f, bias, H_post, x, H_res, n):

     return out

+
 @dataclass
 class MHCConfig:
     s: int = 2048 # Sequence length
@@ -413,6 +415,7 @@ def test_mhc_sinkhorn_knopp(cfg: MHCConfig, dtype, recompute):

     torch.testing.assert_close(x.grad, x_ref.grad, **tols)

+
 @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
 def test_mhc_aggregate(cfg: MHCConfig, dtype):
@@ -456,7 +459,7 @@ def test_mhc_expand_combine(cfg: MHCConfig, dtype, with_bias):
     H_res = torch.randn(s, b, n, n, device="cuda", requires_grad=True, dtype=dtype)

     f_ref = f.detach().clone().requires_grad_(True)
-    bias_ref = None if bias is None else bias.detach().clone().requires_grad_(True)
+    bias_ref = None if bias is None else bias.detach().clone().requires_grad_(True)
     H_post_ref = H_post.detach().clone().requires_grad_(True)
     x_ref = x.detach().clone().requires_grad_(True)
     H_res_ref = H_res.detach().clone().requires_grad_(True)
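
The reformatted line belongs to the reference-cloning pattern these tests rely on: each input is duplicated with detach().clone().requires_grad_(True) so the kernel under test and the reference build independent autograd graphs, and their gradients are then compared with torch.testing.assert_close. A minimal sketch of that pattern with placeholder computations:

import torch

x = torch.randn(8, 4, requires_grad=True)
x_ref = x.detach().clone().requires_grad_(True)  # independent leaf, same values

# Placeholder "kernel" and "reference"; in the tests these are the cuTile/Triton
# kernels and the torch.compile reference implementations.
y = (x * 2).sum()
y_ref = (x_ref * 2).sum()

y.backward()
y_ref.backward()

torch.testing.assert_close(x.grad, x_ref.grad)  # same gradient check as in the tests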
