Skip to content

Commit 856ed09

Browse files
authored
metal : Add template specialization for mul_mm_id w/ ne20 == 10 (#15799)
Branch: GGMLMetalNE20
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent d1e2adb commit 856ed09

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -407,6 +407,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
407407
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
408408
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
409409
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
410+
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10,
410411
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
411412
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
412413
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
@@ -1439,6 +1440,7 @@ @implementation GGMLMetalClass
14391440
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
14401441
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
14411442
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
1443+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10, mul_mm_id_map0_f16_ne20_10, has_simdgroup_mm);
14421444
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
14431445
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
14441446
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
@@ -3979,6 +3981,7 @@ static int ggml_metal_encode_node(
39793981
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
39803982
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
39813983
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
3984+
case 10: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10].pipeline; break;
39823985
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
39833986
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
39843987
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7618,6 +7618,7 @@ template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm
76187618
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
76197619
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
76207620
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
7621+
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
76217622
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
76227623

76237624
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>

0 commit comments

Comments
 (0)