
Commit a1c931c

andrewkchan and Iwan Kawrakow authored
Trellis quants with CPU inference (#441)
* WIP
* WIP
* WIP
* Testing Trellis quantization: Using 12 bits per 8 weights I get a better rmse than iq2_xxs. I still need to see how quantizing the group-of-8 scales will affect accuracy. By AVX2 SIMDifying the search for the best code, LLaMA-3.1-8B gets quantized in 130 seconds on the Ryzen-7950X CPU - sluggish but still acceptable.
* Testing Trellis quantization: 4-bit quantized block scales. rmse increases by just 3%, so this is beating iq2_xxs in terms of rmse at the same 2.0625 bpw.
* Testing Trellis quantization: playing with scales and generators
* iq2_kt: quantize / dequantize. I now see that I was comparing apples to oranges: iq2_xxs was using a weight of sigma^2/4 + x^2, while the Trellis approach wasn't (weight = 1). Once I use the same weight, iq2_kt is actually slightly worse than iq2_xxs in terms of rmse, so it does not look promising at this point. Also, once each group of 8 Trellis values no longer has a constant sum(q^2) that we can precompute, quantization becomes significantly slower (476 seconds for LLaMA-3.1-8B).
* iq2_kt: CUDA dequantize, so we can run perplexity calcs. As already indicated by rmse, the 2-bit trellis approach is quite a bit worse than iq2_xxs.
* WIP
* WIP
* WIP - try larger blocks. With blocks of 32 and 16 bits per group of 8 the brute-force search becomes prohibitive in terms of CPU time (30+ minutes for 8B LLaMA after SIMDifying with AVX2). The trick is to group the points in clusters, find the nearest cluster, and only search within the cluster.
* iq2_kt - this is better. Using blocks of 32 and 16 bits per group of 8 weights it beats iq2_xxs in terms of PPL by a significant margin. It is 0.0625 bpw larger, but even if we go to 15 bits per group of 8 (so 0.0625 bpw less than iq2_xxs), PPL is still lower.
* iq2_kt - even better. Re-quantize after determining block scales (at the expense of much longer quantization time).
* iq2_kt: CUDA dot product. Implemented as DMMV. Very slow - just 81 t/s for LLaMA-3.1-8B. Then again, Q2_K_S forced to use DMMV only gets 112 t/s vs 145 t/s via MMVQ. My memory is that when the DMMV kernels were properly maintained/used, DMMV was about on par with MMVQ for k-quants on my GPU.
* iq2_kt: very slightly faster CUDA dot product
* iq2_kt: f16 CUDA dot product. We arrive at 112 t/s.
* iq2_kt: faster f16 CUDA dot product. We arrive at 139 t/s (no FA) and 149 t/s (FA). My RTX-4080 is ~20% slower than the RTX-6000 quoted in the QTIP repository, so with FA (which I'm sure they also used) we are at around ~180 t/s on their GPU, almost matching their performance.
* iq2_kt: faster f16 CUDA dot product. We arrive at 146 t/s (no FA) and 158 t/s (FA). This is measured for LLaMA-3.1-8B with output.weight left as f16.
* Minor
* Adding iq3_kt: 3.125 bpw. So far does not look good on the PPL vs bpw plot.
* Forgotten change
* WIP
* WIP
* iq3_kt WIP: slowly improving. PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is starting to be competitive/slightly better than other quants.
* WIP
* iq3_kt WIP: slowly improving. PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892.
* iq3_kt WIP: slowly improving. PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
* iq3_kt WIP: speed up quantization. Nearly 60% improvement of quantization speed by having the points belonging to a cluster copied to contiguous memory during initialization, and then accessed sequentially while searching for the closest point. LLaMA-3.1-8B now gets quantized in ~150 seconds on the Ryzen-5975WX.
* iq3_kt: speed up quantization. Same trick as the last commit applied to iq2_kt. Here we get an even larger speedup: quantization time on the Ryzen-5975WX for LLaMA-3.1-8B drops to 195 seconds from 375 seconds!
* iq3_kt: CUDA dot product
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.2406, PPL(LLaMA-2-7B, 4096) = 6.4179.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642, PPL(LLaMA-2-7B, 4096) = 6.3920.
* Adding iq4_kt - not competitive at this point
* WIP
* WIP
* iq4_kt: CUDA dot product
* iq4_kt: minor tweaks
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.1642, PPL(LLaMA-2-7B, 4096) = 6.3920.
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 9.0297, PPL(LLaMA-2-7B, 4096) = 6.3913. Ah, quantization is faster too - about 20% faster.
* iq3_kt: small improvements and faster quantization
* iq2_kt: SOTA. We arrive at PPL(LLaMA-3.1-8B-Instruct, 8192) = 8.9627, PPL(LLaMA-2-7B, 4096) = 6.3825. Quantization is faster too: ~200 seconds for LLaMA-3.1-8B on the Ryzen-5975WX.
* iq3_kt: small progress
* WIP
* iq4_kt: go to 4.0 bpw. 15 bits per group of 4, plus 8-bit scales for blocks of 32. This gives a slightly better PPL than iq4_kss.
* iq4_kt: very slightly better, at the expense of much longer quantization time.
* iq4_kt: failed attempt to adjust the CUDA dot product. It was working for 4.125 bpw, but after changing to 4.0 bpw there is something wrong and I don't see the bug.
* DRY
* DRY
* iq4_kt: CUDA dot product works
* DRY
* Report actual bpw
* Minor tweaks
* Checkpoint. Go to groups of 8 for iq3_kt: 2 x 8 = 16 bits for the magnitude plus 1 bpw for the sign. It gives a visible improvement in the PPL vs bpw plot, but that comes at the expense of much longer quantization time (7.5 minutes for LLaMA-3.1-8B on the Ryzen-5975WX). I also noticed that the 3INST generator is not actually generating a Gaussian distribution. But going to a better generator means readjusting all the hyper-parameters, so leaving it for later.
* WIP for IQ2_KT
* WIP - working basic iq2_kt
* still super slow (0.17 t/s eval)
* flatten 3inst iters + avx2 (0.3 t/s eval)
* iq3_kt (0.3 t/s eval) and renames
* wip buggy iq4_KT
* fix (0.22 t/s eval)
* naming and remove unused fn
* cleanup
* more cleanup
* delete unused and noncompiling mmvq functions
* Some performance tweaks
* Slightly faster iq2_kt
* port Trellis struct to iq3_kt, iq4_kt
* oops untracked files

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
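The "3INST" generator mentioned above is the heart of the trellis scheme; the CUDA helper trellis_next() added to ggml/src/ggml-cuda/convert.cu further down in this diff makes it concrete. The following is a rough, stand-alone CPU sketch (not code from this commit) of the same generator: half_to_float() is a helper written here for illustration, and the two fp16 halves are summed in fp32, which can differ from the GPU's fp16 addition by one rounding step.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Convert an IEEE-754 binary16 bit pattern to float.
static float half_to_float(uint16_t h) {
    const float sign = (h & 0x8000) ? -1.0f : 1.0f;
    const int   exp  = (h >> 10) & 0x1f;
    const int   mant =  h        & 0x3ff;
    if (exp == 0)  return sign * std::ldexp((float)mant, -24);     // zero / subnormal
    if (exp == 31) return mant ? NAN : sign * INFINITY;            // NaN / inf
    return sign * std::ldexp(1.0f + mant/1024.0f, exp - 15);       // normal
}

// CPU analogue of the trellis generator used by the new CUDA kernels:
// one 32-bit LCG step, a mask/xor that forces both 16-bit halves into a
// narrow fp16 range, then the sum of the two halves as the next value.
static float trellis_next(uint32_t & val) {
    constexpr uint32_t ka    = 89226354u;
    constexpr uint32_t kb    = 64248484u;
    constexpr uint32_t kmask = 0x8fff8fffu;
    constexpr uint32_t km32  = 0x3b603b60u;
    val = ka*val + kb;                                  // LCG update
    const uint32_t s = (val & kmask) ^ km32;            // reshape both halves
    return half_to_float((uint16_t)(s & 0xffff)) + half_to_float((uint16_t)(s >> 16));
}

int main() {
    // A group's 16-bit code (plus the 4096 offset used by the kernels) seeds
    // the generator; eight successive outputs form the group of 8 weights,
    // which are then multiplied by the block/row scales.
    uint32_t idx = 123 + 4096;                          // 123 is an arbitrary example code
    for (int j = 0; j < 8; ++j) printf("%.4f\n", trellis_next(idx));
    return 0;
}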
1 parent 3efdd6d commit a1c931c

21 files changed: +3028 −25 lines changed
Lines changed: 16 additions & 1 deletion
@@ -1,6 +1,21 @@
+set(ARCH_FLAGS "")
+if (NOT MSVC)
+    list(APPEND ARCH_FLAGS -march=native)
+endif()
+message(STATUS "ARCH_FLAGS = ${ARCH_FLAGS}")
+#if (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
+#    (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
+#     CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
+#    message(STATUS "x86 detected")
+#    if (NOT MSVC)
+#        list(APPEND ARCH_FLAGS -march=native)
+#    endif()
+#endif()
+
+add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
 set(TARGET llama-quantize-stats)
 add_executable(${TARGET} quantize-stats.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT})
 target_include_directories(${TARGET} PRIVATE ../../common)
-target_compile_features(${TARGET} PRIVATE cxx_std_11)
+target_compile_features(${TARGET} PRIVATE cxx_std_17)

examples/quantize-stats/quantize-stats.cpp

Lines changed: 579 additions & 16 deletions
Large diffs are not rendered by default.

examples/quantize/quantize.cpp

Lines changed: 3 additions & 0 deletions
@@ -46,6 +46,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q2_K_R4", LLAMA_FTYPE_MOSTLY_Q2_K_R4, "Q2_K_S repacked", },
     { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
     { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
+    { "IQ3_KT", LLAMA_FTYPE_MOSTLY_IQ3_KT, " 3.125 bpw trellis quantization", },
+    { "IQ4_KT", LLAMA_FTYPE_MOSTLY_IQ4_KT, " 4.0 bpw trellis quantization", },
     { "IQ3_XXS_R4",LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4,"IQ3_XXS repacked", },
     { "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", },
     { "IQ3_S_R4", LLAMA_FTYPE_MOSTLY_IQ3_S_R4, "IQ3_S repacked", },
@@ -73,6 +75,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
     { "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",},
     { "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",},
+    { "IQ2_KT", LLAMA_FTYPE_MOSTLY_IQ2_KT, " 2.125 bpw trellis quantization", },
     { "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", },
     { "IQ3_K_R4", LLAMA_FTYPE_MOSTLY_IQ3_K_R4, "IQ3_K repacked", },
     { "IQ3_KL", LLAMA_FTYPE_MOSTLY_IQ3_KL, " 4 bpw non-linear quantization mix",},

ggml/include/ggml.h

Lines changed: 6 additions & 0 deletions
@@ -426,6 +426,9 @@ extern "C" {
         GGML_TYPE_Q8_K128 = 150,
         GGML_TYPE_Q8_KV = 151,
         GGML_TYPE_IQ5_KS = 152,
+        GGML_TYPE_IQ2_KT = 153,
+        GGML_TYPE_IQ3_KT = 154,
+        GGML_TYPE_IQ4_KT = 155,
 
         GGML_TYPE_Q4_0_R8 = 202,
         GGML_TYPE_Q5_0_R4 = 206,
@@ -515,6 +518,9 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q8_KV = 140, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ5_KS = 141, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_KT = 142, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_KT = 143, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ4_KT = 144, // except 1d tensors
         //
         GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors

ggml/src/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -268,6 +268,7 @@ if (GGML_IQK_MUL_MAT)
         iqk/fa/iqk_fa_64_64.cpp
         iqk/iqk_gemm_floats.cpp
         iqk/iqk_gemm_kquants.cpp
+        iqk/iqk_gemm_ktquants.cpp
         iqk/iqk_gemm_iquants.cpp
         iqk/iqk_gemm_iqk_quants.cpp
         iqk/iqk_gemm_1bit.cpp
@@ -277,6 +278,7 @@
         iqk/fa/iqk_fa_templates.h
         iqk/iqk_gemm_floats.h
         iqk/iqk_gemm_kquants.h
+        iqk/iqk_gemm_ktquants.h
         iqk/iqk_gemm_iquants.h
         iqk/iqk_gemm_iqk_quants.h
         iqk/iqk_gemm_1bit.h

ggml/src/ggml-common.h

Lines changed: 18 additions & 0 deletions
@@ -620,6 +620,24 @@ typedef struct {
 } block_iq2_ks;
 static_assert(sizeof(block_iq2_ks) == sizeof(uint16_t) + QK_K/64 + QK_K/4, "wrong iq2_ks block size/padding");
 
+typedef struct {
+    uint8_t scales[QK_K/64];
+    uint8_t ql[QK_K/4];
+} block_iq2_kt;
+static_assert(sizeof(block_iq2_kt) == QK_K/4 + QK_K/64, "wrong iq2_kt block size/padding");
+
+typedef struct {
+    uint8_t scales[QK_K/64];
+    uint8_t ql[QK_K/4];
+    uint8_t qh[QK_K/8];
+} block_iq3_kt;
+static_assert(sizeof(block_iq3_kt) == QK_K/4 + QK_K/8 + QK_K/64, "wrong iq3_kt block size/padding");
+
+typedef struct {
+    uint32_t qs[QK_K/8];
+} block_iq4_kt;
+static_assert(sizeof(block_iq4_kt) == QK_K/2, "wrong iq4_kt block size/padding");
+
 typedef struct {
     ggml_half d;
     uint16_t extra;
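The three new block layouts line up with the bits-per-weight figures advertised in quantize.cpp. A quick sanity check (not part of the diff), assuming the usual QK_K == 256; the float scale(s) stored per row ahead of the blocks add only a small constant on top:

#include <cstdio>

int main() {
    constexpr int QK_K = 256;
    const int iq2_kt_bytes = QK_K/64 + QK_K/4;             //  4 + 64      =  68 bytes
    const int iq3_kt_bytes = QK_K/64 + QK_K/4 + QK_K/8;    //  4 + 64 + 32 = 100 bytes
    const int iq4_kt_bytes = 4*(QK_K/8);                   // 32 x uint32  = 128 bytes
    printf("iq2_kt: %.4f bpw\n", 8.0*iq2_kt_bytes/QK_K);   // 2.1250
    printf("iq3_kt: %.4f bpw\n", 8.0*iq3_kt_bytes/QK_K);   // 3.1250
    printf("iq4_kt: %.4f bpw\n", 8.0*iq4_kt_bytes/QK_K);   // 4.0000
    return 0;
}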

ggml/src/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
@@ -2111,6 +2111,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
+        && ggml_cuda_mmvq_type_supported(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
     bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
@@ -3460,6 +3461,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_TYPE_IQ5_KS:
         case GGML_TYPE_IQ2_K:
         case GGML_TYPE_IQ2_KS:
+        case GGML_TYPE_IQ2_KT:
+        case GGML_TYPE_IQ3_KT:
+        case GGML_TYPE_IQ4_KT:
         case GGML_TYPE_IQ3_K:
         case GGML_TYPE_IQ4_K:
         case GGML_TYPE_IQ5_K:

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
@@ -564,6 +564,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KS> {
     static constexpr int qi = QI4_XS;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
     static constexpr int qk = QK_K;

ggml/src/ggml-cuda/convert.cu

Lines changed: 128 additions & 0 deletions
@@ -333,6 +333,101 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
     for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 }
 
+inline __device__ int nearest_int(float fval) {
+    assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+float __device__ __forceinline__ trellis_next(uint32_t& val) {
+    constexpr uint32_t ka = 89226354;
+    constexpr uint32_t kb = 64248484;
+    constexpr uint32_t kmask = 0x8fff8fff;
+    constexpr uint32_t km32 = 0x3b603b60;
+    uint32_t s;
+    const half * h = (const half *)&s;
+    val = ka*val + kb;
+    s = (val & kmask) ^ km32;
+    return (float)(h[0]+h[1]);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+    int64_t ii = blockIdx.x;
+    int64_t row = (QK_K * ii) / n_per_row;
+    const char * cx = (const char *)vx + row * row_size;
+    float scale = *(const float *)cx;
+    const block_iq2_kt * x = (const block_iq2_kt *)(cx + sizeof(float));
+    const int64_t i = ii - (row*n_per_row)/QK_K;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t ib = tid; // 0...31
+    dst_t * y = yy + ii*QK_K + 8*ib;
+    const uint16_t * ql = (const uint16_t *)x[i].ql;
+    uint32_t idx = ql[ib] + 4096;
+    const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
+    for (int j = 0; j < 8; ++j) {
+        y[j] = dl * trellis_next(idx);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+    int64_t ii = blockIdx.x;
+    int64_t row = (QK_K * ii) / n_per_row;
+    const char * cx = (const char *)vx + row * row_size;
+    float scale = *(const float *)cx;
+    const block_iq3_kt * x = (const block_iq3_kt *)(cx + sizeof(float));
+    const int64_t i = ii - (row*n_per_row)/QK_K;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t ib = tid; // 0...31
+    dst_t * y = yy + ii*QK_K + 8*ib;
+    const uint16_t * ql = (const uint16_t *)x[i].ql;
+    uint32_t idx = ql[ib] + 4096;
+    const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
+    uint8_t mask = 1 << (ib/4);
+    for (int j = 0; j < 8; ++j) {
+        y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
+
+    int64_t ii = blockIdx.x;
+    int64_t row = (QK_K * ii) / n_per_row;
+    const float * dptr = (const float *)((const char *)vx + row * row_size);
+    float scale = dptr[0] * 31.75f * 1.01f;
+    float row_av = dptr[1];
+    const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
+    const int64_t i = ii - (row*n_per_row)/QK_K;
+
+    constexpr int kNumGroups = 64;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t ib = tid; // 0...31
+    dst_t * y = yy + ii*QK_K + 8*ib;
+    const uint32_t * shb = x[i].qs;
+    const uint8_t * ql = (const uint8_t *)(shb + 8); //Q::kNblock;
+    const uint8_t * qh = ql + kNumGroups;
+    const int ib32 = ib/4;
+    const int ig = ib%4;
+    const int jj = ib32*8 + 2*ig;
+    uint32_t offset = shb[ib32] & 1 ? 4096 + 32768 : 4096;
+    uint32_t idx1 = ql[jj+0] + ((qh[(jj+0)%(kNumGroups/2)] << (8 - 4*((jj+0)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+0)) & 7) << 12) + offset;
+    uint32_t idx2 = ql[jj+1] + ((qh[(jj+1)%(kNumGroups/2)] << (8 - 4*((jj+1)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+3)) & 7) << 12) + offset;
+    int ls = ((shb[ib32] & 0xff) >> 1) - 64;
+    const float dl = scale * ls;
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = dl * trellis_next(idx1) + row_av;
+        y[j+4] = dl * trellis_next(idx2) + row_av;
+    }
+}
+
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
@@ -968,6 +1063,27 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
     dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = k / QK_K;
+    dequantize_block_iq2_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row));
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = k / QK_K;
+    dequantize_block_iq3_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row));
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = k / QK_K;
+    dequantize_block_iq4_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row));
+}
+
 template<typename dst_t>
 static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
@@ -1230,6 +1346,12 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q6_K_cuda;
         case GGML_TYPE_IQ2_XXS:
             return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_KT:
+            return dequantize_row_iq2_kt_cuda;
+        case GGML_TYPE_IQ3_KT:
+            return dequantize_row_iq3_kt_cuda;
+        case GGML_TYPE_IQ4_KT:
+            return dequantize_row_iq4_kt_cuda;
         case GGML_TYPE_IQ2_XS:
             return dequantize_row_iq2_xs_cuda;
         case GGML_TYPE_IQ2_S:
@@ -1303,6 +1425,12 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q6_K_cuda;
         case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_KT:
+            return dequantize_row_iq2_kt_cuda;
+        case GGML_TYPE_IQ3_KT:
+            return dequantize_row_iq3_kt_cuda;
+        case GGML_TYPE_IQ4_KT:
+            return dequantize_row_iq4_kt_cuda;
         case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_cuda;
         case GGML_TYPE_IQ2_S:
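For reference, the indexing in dequantize_block_iq2_kt maps directly onto a plain CPU loop: one super-block of QK_K = 256 values per CUDA block, and 32 threads each expanding one group of 8 from a 16-bit code. Below is a rough sketch (not code from this commit), assuming the trellis_next() generator sketched after the commit message, the iq4k_values table from the kernel above (its element type is assumed here), and the block_iq2_kt layout from ggml-common.h.

#include <cstdint>

// Layout constants / struct as introduced in ggml-common.h by this commit.
constexpr int QK_K = 256;
struct block_iq2_kt {
    uint8_t scales[QK_K/64];
    uint8_t ql[QK_K/4];
};

// Assumed available: the generator sketched earlier and the 16-entry
// non-linear scale table used by the CUDA kernel (element type assumed).
float trellis_next(uint32_t & val);
extern const int8_t iq4k_values[];

// CPU reference for one row of IQ2_KT data (layout: float row scale, then
// n_per_row/QK_K block_iq2_kt super-blocks), mirroring dequantize_block_iq2_kt.
void dequantize_row_iq2_kt_ref(const char * row_data, float * y, int64_t n_per_row) {
    const float d = *(const float *)row_data;                      // per-row scale
    const block_iq2_kt * x = (const block_iq2_kt *)(row_data + sizeof(float));
    for (int64_t i = 0; i < n_per_row/QK_K; ++i) {                 // super-blocks of 256
        const uint16_t * ql = (const uint16_t *)x[i].ql;           // 32 x 16-bit codes
        for (int ib = 0; ib < 32; ++ib) {                          // groups of 8
            uint32_t idx = ql[ib] + 4096;                          // seed = code + offset
            const float dl = d * iq4k_values[(x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf]
                               * 31.75f * 1.05f;                   // same constants as the kernel
            for (int j = 0; j < 8; ++j) y[i*QK_K + 8*ib + j] = dl * trellis_next(idx);
        }
    }
}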
