Skip to content

Commit 73e07ad

Browse files
committed
vulkan: mul_mat_id coopmat2 optimizations
Add a path for when the tile fits in BN/2, similar to what we have for mul_mat. Only call fetch_scales/store_scales once per QUANT_K block, and once at the beginning in case start_k is not aligned.
1 parent 792b44f commit 73e07ad

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2225,7 +2225,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
22252225
s_mmq_wg_denoms_k = { 32, 64, 1 };
22262226

22272227
// spec constants and tile sizes for quant matmul_id
2228-
l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
2228+
l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
22292229
m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
22302230
s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
22312231
l_mmqid_wg_denoms = { 128, 128, 1 };

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 51 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -444,18 +444,64 @@ void main() {
444444

445445
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
446446

447-
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
448-
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
449-
450447
uint k_iters = (end_k - start_k + BK - 1) / BK;
451448

452449
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
450+
store_scales(tid);
451+
452+
#ifdef MUL_MAT_ID
453+
if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
454+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
455+
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
456+
457+
[[dont_unroll]]
458+
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
459+
460+
if ((block_k % QUANT_K) == 0) {
461+
store_scales(tid);
462+
}
463+
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
464+
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
465+
}
466+
467+
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
468+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
469+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
470+
471+
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
472+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
473+
474+
sum = coopMatMulAdd(mat_a, mat_b, sum);
475+
} else {
476+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
477+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
478+
479+
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
480+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
481+
482+
sum = coopMatMulAdd(mat_a, mat_b, sum);
483+
}
484+
}
485+
486+
// Convert from ACC_TYPE to D_TYPE
487+
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
488+
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
489+
490+
// Call callback to store each element, remapping row through shared memory
491+
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
492+
return;
493+
}
494+
#endif
495+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
496+
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
453497

454498
[[dont_unroll]]
455499
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
456500

457-
store_scales(tid);
458-
if (block_k + BK < end_k) {
501+
if ((block_k % QUANT_K) == 0) {
502+
store_scales(tid);
503+
}
504+
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
459505
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
460506
}
461507

0 commit comments

Comments (0)