2626 mHCSinkhornRef ,
2727 mHCElementwiseRef ,
2828 mHCPreRef ,
29- mHCPostResRef
29+ mHCPostResRef ,
3030)
3131
32+
3233def run_sinkhorn_triton (B , T , n , dtype , device , iters , do_backward ):
3334 nvtx .range_push ("mhc_sinkhorn_triton_fwd" )
3435 x = torch .randn ((B , T , n , n ), device = device , dtype = dtype , requires_grad = do_backward )
@@ -39,6 +40,7 @@ def run_sinkhorn_triton(B, T, n, dtype, device, iters, do_backward):
3940 y .sum ().backward ()
4041 nvtx .range_pop ()
4142
43+
4244def run_sinkhorn_cutile (B , T , n , dtype , device , iters , do_backward ):
4345 nvtx .range_push ("mhc_sinkhorn_cutile_fwd" )
4446 x = torch .randn ((B , T , n , n ), device = device , dtype = dtype , requires_grad = do_backward )
@@ -49,6 +51,7 @@ def run_sinkhorn_cutile(B, T, n, dtype, device, iters, do_backward):
4951 y .sum ().backward ()
5052 nvtx .range_pop ()
5153
54+
5255def run_sinkhorn_compile (B , T , n , dtype , device , iters , do_backward ):
5356 nvtx .range_push ("mhc_sinkhorn_compile_fwd" )
5457 x = torch .randn ((B , T , n , n ), device = device , dtype = dtype , requires_grad = do_backward )
@@ -59,6 +62,7 @@ def run_sinkhorn_compile(B, T, n, dtype, device, iters, do_backward):
5962 y .sum ().backward ()
6063 nvtx .range_pop ()
6164
65+
6266def run_sinkhorn (B , T , n , dtype , device , iters , do_backward ):
6367 run_sinkhorn_cutile (B , T , n , dtype , device , iters , do_backward )
6468 run_sinkhorn_triton (B , T , n , dtype , device , iters , do_backward )
@@ -78,6 +82,7 @@ def run_projection_triton(B, T, n, C, dtype, device, do_backward):
7882 (Hs .sum () + r .sum ()).backward ()
7983 nvtx .range_pop ()
8084
85+
8186def run_projection_cutile (B , T , n , C , dtype , device , do_backward ):
8287 nC = n * C
8388 N = 2 * n + n * n
@@ -91,6 +96,7 @@ def run_projection_cutile(B, T, n, C, dtype, device, do_backward):
9196 (Hs .sum () + r .sum ()).backward ()
9297 nvtx .range_pop ()
9398
99+
94100def run_projection_compile (B , T , n , C , dtype , device , do_backward ):
95101 nC = n * C
96102 N = 2 * n + n * n
@@ -104,6 +110,7 @@ def run_projection_compile(B, T, n, C, dtype, device, do_backward):
104110 (Hs .sum () + r .sum ()).backward ()
105111 nvtx .range_pop ()
106112
113+
107114def run_projection (B , T , n , C , dtype , device , do_backward ):
108115 run_projection_cutile (B , T , n , C , dtype , device , do_backward )
109116 run_projection_triton (B , T , n , C , dtype , device , do_backward )
@@ -124,6 +131,7 @@ def run_elementwise_triton(B, T, n, dtype, device, do_backward):
124131 out .sum ().backward ()
125132 nvtx .range_pop ()
126133
134+
127135def run_elementwise_compile (B , T , n , dtype , device , do_backward ):
128136 N = 2 * n + n * n
129137 nvtx .range_push ("mhc_elementwise_compile_fwd" )
@@ -138,6 +146,7 @@ def run_elementwise_compile(B, T, n, dtype, device, do_backward):
138146 out .sum ().backward ()
139147 nvtx .range_pop ()
140148
149+
141150def run_elementwise (B , T , n , dtype , device , do_backward ):
142151 run_elementwise_triton (B , T , n , dtype , device , do_backward )
143152 run_elementwise_compile (B , T , n , dtype , device , do_backward )
@@ -154,6 +163,7 @@ def run_pre_triton(B, T, n, C, dtype, device, do_backward):
154163 out .sum ().backward ()
155164 nvtx .range_pop ()
156165
166+
157167def run_pre_cutile (B , T , n , C , dtype , device , do_backward ):
158168 nvtx .range_push ("mhc_pre_cutile_fwd" )
159169 x = torch .randn (B , T , n , C , dtype = dtype , requires_grad = True , device = device )
@@ -165,6 +175,7 @@ def run_pre_cutile(B, T, n, C, dtype, device, do_backward):
165175 out .sum ().backward ()
166176 nvtx .range_pop ()
167177
178+
168179def run_pre_compile (B , T , n , C , dtype , device , do_backward ):
169180 nvtx .range_push ("mhc_pre_compile_fwd" )
170181 x = torch .randn (B , T , n , C , dtype = dtype , requires_grad = True , device = device )
@@ -176,6 +187,7 @@ def run_pre_compile(B, T, n, C, dtype, device, do_backward):
176187 out .sum ().backward ()
177188 nvtx .range_pop ()
178189
190+
179191def run_pre (B , T , n , C , dtype , device , do_backward ):
180192 run_pre_cutile (B , T , n , C , dtype , device , do_backward )
181193 run_pre_triton (B , T , n , C , dtype , device , do_backward )
@@ -195,6 +207,7 @@ def run_post_res_triton(B, T, n, C, dtype, device, do_backward):
195207 out .sum ().backward ()
196208 nvtx .range_pop ()
197209
210+
198211def run_post_res_cutile (B , T , n , C , dtype , device , do_backward ):
199212 nvtx .range_push ("mhc_post_res_cutile_fwd" )
200213 x = torch .randn (B , T , n , C , dtype = dtype , requires_grad = True , device = device )
@@ -208,6 +221,7 @@ def run_post_res_cutile(B, T, n, C, dtype, device, do_backward):
208221 out .sum ().backward ()
209222 nvtx .range_pop ()
210223
224+
211225def run_post_res_compile (B , T , n , C , dtype , device , do_backward ):
212226 nvtx .range_push ("mhc_post_res_compile_fwd" )
213227 x = torch .randn (B , T , n , C , dtype = dtype , requires_grad = True , device = device )
@@ -221,14 +235,20 @@ def run_post_res_compile(B, T, n, C, dtype, device, do_backward):
221235 out .sum ().backward ()
222236 nvtx .range_pop ()
223237
238+
def run_post_res(B, T, n, C, dtype, device, do_backward):
    """Run the post-res benchmark for each backend in turn.

    The cuTile, Triton and compile variants are invoked in that order, and
    every one of them receives exactly the arguments passed to this function.
    """
    # Pack the shared arguments once so each backend call is visibly identical.
    shared_args = (B, T, n, C, dtype, device, do_backward)
    run_post_res_cutile(*shared_args)
    run_post_res_triton(*shared_args)
    run_post_res_compile(*shared_args)
228243
244+
229245def main ():
230246 parser = argparse .ArgumentParser ()
231- parser .add_argument ("--operation" , choices = ["sinkhorn" , "projection" , "elementwise" , "pre" , "post_res" , "all" ], required = True )
247+ parser .add_argument (
248+ "--operation" ,
249+ choices = ["sinkhorn" , "projection" , "elementwise" , "pre" , "post_res" , "all" ],
250+ required = True ,
251+ )
232252 parser .add_argument ("--dtype" , choices = ["float32" , "bfloat16" ], default = "float32" )
233253 parser .add_argument ("--warmup" , type = int , default = 2 )
234254 parser .add_argument ("--iters" , type = int , default = 1 )
0 commit comments