@@ -27,7 +27,7 @@ def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
27
27
BLOCK_M: M dimension of the output block m x n
28
28
BLOCK_N: N dimension of the output block m x n
29
29
BLOCK_K: K dimension atom
30
- EVEN_K: True iff the blocks of A and B can be loaded without any
30
+ EVEN_K: True if the blocks of A and B can be loaded without any
31
31
masking.
32
32
SPLIT_K: Parameter signifying parallelism in the K dimension.
33
33
CAST_TYPE: if True, cast the values from the A matrix to the B
@@ -127,7 +127,7 @@ def do_expand_kernel(
127
127
offset_n = tl .arange (0 , BLOCK_N ) + pid_n * BLOCK_N
128
128
rbn = tl .max_contiguous (tl .multiple_of (offset_n % N , BLOCK_N ), BLOCK_N )
129
129
130
- # Identify the row pointers of A and column pointers of B
130
+ # Identify A and B block pointers
131
131
offset_k = tl .arange (0 , BLOCK_K )
132
132
a_ptr = (cur_input_ptr + ram [:, None ] * input_d1_stride +
133
133
offset_k [None , :] * input_d2_stride , )
@@ -213,7 +213,7 @@ def do_shrink_kernel(
213
213
offset_n = tl .arange (0 , BLOCK_N ) + pid_n * BLOCK_N
214
214
rbn = tl .max_contiguous (tl .multiple_of (offset_n % N , BLOCK_N ), BLOCK_N )
215
215
216
- # Identify the row pointers of A and column pointers of B
216
+ # Identify A and B block pointers
217
217
offset_k = pid_sk * BLOCK_K + tl .arange (0 , BLOCK_K )
218
218
a_ptr = (input_ptr + ram [:, None ] * input_d0_stride +
219
219
offset_k [None , :] * input_d1_stride )
0 commit comments