178 changes: 84 additions & 94 deletions ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -115,7 +115,8 @@ struct Trellis2 {
const __m256i mask2 = _mm256_set1_epi32(km32);

inline __m256i next8(uint32_t val1, uint32_t val2) {
__m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
__m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1));
//__m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2);
__m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb);
return _mm256_and_si256(mres, _mm256_set1_epi32(kmask)) ^ _mm256_set1_epi32(km32);
}
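
The replaced line builds the same [val1, val1, val1, val1, val2, val2, val2, val2] vector as before; MM256_SET_M128I just assembles it from two cheaper _mm_set1_epi32 broadcasts (the ggml macro puts its second argument in the low 128 bits). As a minimal scalar sketch of what next8 then produces per lane, assuming mka/mkb hold per-lane constants ka[i]/kb[i] and kmask/km32 are the struct's masking constants:

#include <cstdint>

// Scalar sketch (illustration only) of the per-lane computation in next8(),
// assuming mka/mkb store per-lane constants ka[i]/kb[i].
static inline void next8_scalar(uint32_t val1, uint32_t val2,
                                const uint32_t ka[8], const uint32_t kb[8],
                                uint32_t kmask, uint32_t km32,
                                uint32_t out[8]) {
    for (int i = 0; i < 8; ++i) {
        uint32_t v = (i < 4) ? val1 : val2;   // lanes 0..3 from val1, lanes 4..7 from val2
        out[i] = ((v * ka[i] + kb[i]) & kmask) ^ km32;
    }
}

trellis_gen8() then maps these eight 32-bit words to eight floats; this hunk does not touch that step.
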
@@ -189,35 +190,21 @@ static inline __m256 abs_ps(__m256 vals) {
return _mm256_andnot_ps(sign_bit, vals);
}

// Negates 32-bit float lanes of an 8x32-bit vector
// based on 8x8-bit condition var. For float lane i, if byte i of
// `condition` is nonzero, the float will be negated.
static inline __m256 conditional_negate_ps(__m256 vals, uint64_t condition_mask_u64) {
__m128i condition_bytes = _mm_set_epi64x(0, condition_mask_u64);
// Make `should_negate_byte_mask` where byte i == 0xFF if byte i in condition_bytes is nonzero,
// else 0x00 (upper bytes are meaningless)
__m128i zeros = _mm_setzero_si128();
__m128i is_zero_byte_mask = _mm_cmpeq_epi8(condition_bytes, zeros);
__m128i should_negate_byte_mask = _mm_cmpeq_epi8(is_zero_byte_mask, zeros);
// Widen lower 8x8 bits of `should_negate_byte_mask` to 8x32 bits by padding zeros
// expanded_mask_epi32[j] will be 0x000000FF if vals[j] should be negated, zero otherwise
__m256i expanded_mask_epi32 = _mm256_cvtepu8_epi32(should_negate_byte_mask);
// Same as above but with all 32 bits of lane j set if vals[j] should be negated (use to make XOR mask)
__m256i full_dword_negate_mask = _mm256_cmpgt_epi32(expanded_mask_epi32, _mm256_setzero_si256());
// Negate via XOR on sign bits of each 32-bit float
__m256i sign_bit_pattern = _mm256_set1_epi32(0x80000000); // MSB set for a 32-bit value
__m256i xor_mask_epi32 = _mm256_and_si256(full_dword_negate_mask, sign_bit_pattern);
__m256 xor_mask_ps = _mm256_castsi256_ps(xor_mask_epi32);
return _mm256_xor_ps(vals, xor_mask_ps);
}

template <int nrc_y>
static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
const int nb = n/QK_K;

Trellis1 trellis;

union { __m256 vec; float val[8]; } s_helper;

auto shifts = _mm_set_epi32(0, 0, 4, 0);

__m256i all_signs[4];
auto mask1 = _mm256_set1_epi32(0x01);
auto mask2 = _mm256_set1_epi32(0x10);

__m256 accd[nrc_y];
const float * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
@@ -232,31 +219,28 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataIn
for (int i = 0; i < nb; ++i) {
const uint16_t * ql = (const uint16_t *)x[i].ql;
const uint8_t * qh = x[i].qh;
for (int j = 0; j < 128; j+=8) {
uint64_t mask1 = 0x0101010101010101 << (j/32);
uint64_t mask2 = mask1 << 4;
uint32_t val1 = ql[j/8] + 4096;
uint32_t val2 = ql[j/8+16] + 4096;
const uint64_t signs = *((const uint64_t *)(qh + (j%32)));
const float x_scale1 = (x[i].scales[j/32] & 0xf);
const float x_scale2 = (x[i].scales[j/32] >> 4);
const __m256 x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
const __m256 x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_fmadd_ps(
conditional_negate_ps(
_mm256_load_ps(y[iy] + i*QK_K+j), signs & mask1
),
_mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
accd[iy]
);
accd[iy] = _mm256_fmadd_ps(
conditional_negate_ps(
_mm256_load_ps(y[iy] + i*QK_K+j+128), signs & mask2
),
_mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
accd[iy]
);
auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales);
s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf));
auto s32 = _mm256_cvtepi8_epi32(s8);
s_helper.vec = _mm256_cvtepi32_ps(s32);
for (int j = 0; j < 4; ++j) all_signs[j] = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*j)));
for (int ib = 0; ib < 4; ++ib) {
auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
for (int j = 0; j < 4; ++j) {
uint32_t val1 = ql[4*ib+j ] + 4096;
uint32_t val2 = ql[4*ib+j+16] + 4096;
auto sign1 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask1), mask1), _mm256_set1_epi32(0x80000000));
auto sign2 = _mm256_and_si256(_mm256_cmpeq_epi32(_mm256_and_si256(all_signs[j], mask2), mask2), _mm256_set1_epi32(0x80000000));
all_signs[j] = _mm256_srli_epi32(all_signs[j], 1);
auto x_val1 = abs_ps(trellis_gen8(trellis.next8(val1)));
auto x_val2 = abs_ps(trellis_gen8(trellis.next8(val2)));
x_val1 = _mm256_mul_ps(scale1, _mm256_xor_ps(x_val1, _mm256_castsi256_ps(sign1)));
x_val2 = _mm256_mul_ps(scale2, _mm256_xor_ps(x_val2, _mm256_castsi256_ps(sign2)));
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j ), x_val1, accd[iy]);
accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128), x_val2, accd[iy]);
}
}
}
}
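
Both the old and the new loop implement the same per-value lookup: a 4-bit block scale (low nibble for the first 128 values of the super-block, high nibble for the second 128) and one sign bit per value taken from qh. A minimal scalar sketch of that lookup, assuming QK_K == 256 and the bit layout used above:

#include <cstdint>

// Illustration only: scale and sign applied to one |trellis| magnitude in iq3_kt,
// assuming scales[] holds 4 bytes (two 4-bit scales each) and qh[] holds 32 sign bytes.
static inline float iq3_kt_apply(const uint8_t *scales, const uint8_t *qh,
                                 int idx /* 0..255 within the super-block */,
                                 float magnitude /* |value| from the trellis */) {
    int half = idx / 128;                       // 0: values 0..127, 1: values 128..255
    int j    = idx % 128;
    int ib   = j / 32;                          // 32-value block inside the half
    float scale = half ? (scales[ib] >> 4) : (scales[ib] & 0xf);
    int sign = (qh[j % 32] >> (ib + 4*half)) & 1;
    return scale * (sign ? -magnitude : magnitude);
}

The new version precomputes all eight scales at once (s_helper) and keeps the sign bytes widened to 32 bits (all_signs), shifting them right by one per 32-value block instead of rebuilding a 64-bit byte mask per group of 8.
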
@@ -276,66 +260,72 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn

Trellis2 trellis;

__m256 accd[nrc_y];
__m256 accd2[nrc_y];
union { __m256 vec; float val[8]; } s_helper;
union { __m256i vec; uint32_t val[8]; } o_helper;

constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;

__m256 accd[k_acc];
const float * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
float row_sum[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
y[iy] = (const float *)info.src1_row(iy);
auto sum = _mm256_setzero_ps();
for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i));
row_sum[iy] = hsum_float_8(sum);
}

for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
const float d = dptr[0] * 31.75f * 1.01f;
const float row_av = dptr[1];
auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f);
auto dav = dptr[1];
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);

for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_setzero_ps();
accd2[iy] = _mm256_setzero_ps();
}
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = _mm256_setzero_ps();

for (int i = 0; i < nb; ++i) {
auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs);
const uint32_t * shb = x[i].qs;
const uint8_t * ql = (const uint8_t *)(shb + 8);
const uint8_t * qh = ql + kNumGroups;
for (int j = 0; j < 128; j+=8) {
const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15);
const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15);
const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64;
const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64;
const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4));
const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4));
uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1;
uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2;
uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1;
uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2;
const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3));
const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4));
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_fmadd_ps(
_mm256_load_ps(y[iy] + i*QK_K+j),
_mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1),
accd[iy]
);
accd[iy] = _mm256_fmadd_ps(
_mm256_load_ps(y[iy] + i*QK_K+j+128),
_mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2),
accd[iy]
);
accd2[iy] = _mm256_add_ps(
_mm256_load_ps(y[iy] + i*QK_K+j),
accd2[iy]
);
accd2[iy] = _mm256_add_ps(
_mm256_load_ps(y[iy] + i*QK_K+j+128),
accd2[iy]
);
auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1);
s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64))));
o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
for (int ib = 0; ib < 4; ++ib) {
auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]);
auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]);
for (int j = 0; j < 4; ++j) {
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3)));
auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4)));
if constexpr (nrc_y == 1) {
auto y1 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+ 0);
auto y2 = _mm256_load_ps(y[0] + i*QK_K+32*ib+8*j+128);
accd[0] = _mm256_fmadd_ps(y1, x_val1, accd[0]);
accd[1] = _mm256_fmadd_ps(y2, x_val2, accd[1]);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
auto y1 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+ 0);
auto y2 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128);
accd[iy] = _mm256_fmadd_ps(y1, x_val1, accd[iy]);
accd[iy] = _mm256_fmadd_ps(y2, x_val2, accd[iy]);
}
}
}
}
}
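
For reference, each group of 8 values in the loop above is generated from one trellis seed assembled from a ql byte, a qh nibble, three bits of the block's shb word, and the 4096 / bit-15 offset. A minimal scalar sketch of that assembly, following the indexing of val1..val4 (illustration only):

#include <cstdint>

// Illustration only: seed for group g (0..7) of block ib (0..3);
// half = 0 for the first 128 values of the super-block, 1 for the second 128.
static inline uint32_t iq4_kt_seed(const uint32_t *shb, const uint8_t *ql, const uint8_t *qh,
                                   int ib, int half, int g) {
    uint32_t sb     = shb[ib + 4*half];
    uint32_t offset = 4096 + ((sb & 1) << 15);                           // o_helper above
    uint32_t hi3    = (sb >> (8 + 3*g)) & 7;                             // 3 bits per group from shb
    uint32_t lo8    = ql[8*ib + g + 32*half];                            // low 8 bits
    uint32_t mid4   = half ? (qh[8*ib + g] >> 4) : (qh[8*ib + g] & 0xf); // one qh nibble
    return lo8 + (mid4 << 8) + (hi3 << 12) + offset;
}

The scale applied to the group is s_helper.val[ib + 4*half], i.e. d * (((sb & 0xff) >> 1) - 64) as computed from vshb above.
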

for (int iy = 0; iy < nrc_y; ++iy) {
__m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]);
__m256 res2 = _mm256_mul_ps(_mm256_set1_ps(row_av), accd2[iy]);
info.store(ix, iy, hsum_float_8(res) + hsum_float_8(res2));
if constexpr (nrc_y == 1) {
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accd[0], accd[1])) + dav*row_sum[0]);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]) + dav*row_sum[iy]);
}
}
}
}
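
The second accumulator (accd2) and the final d / row_av multiplies are gone because the dot product factors: with each dequantized value modeled as d*q + dav, as the old accumulation already did, dot(y, x) = d*dot(y, q) + dav*sum(y). sum(y) does not depend on the x row, so it is now computed once per y row (row_sum) while d is folded into the block scales. A minimal scalar sketch of the identity (illustration only):

// Illustration only: the factoring used above. With x[k] = d*q[k] + dav,
//     dot(y, x) = d*dot(y, q) + dav*sum(y),
// so sum(y) can be precomputed per y row instead of re-accumulated for every x row.
static float dot_with_row_offset(const float *y, const float *q, int n, float d, float dav) {
    float dq = 0.0f, sy = 0.0f;
    for (int k = 0; k < n; ++k) {
        dq += y[k] * q[k];
        sy += y[k];
    }
    // matches hsum_float_8(accd[iy]) + dav*row_sum[iy] above (d is already folded into accd)
    return d * dq + dav * sy;
}

For nrc_y == 1 the kernel keeps two independent accumulators (k_acc == 2) and only adds them at the end, presumably to avoid a dependency chain between the two fmadds per iteration.
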
@@ -400,4 +390,4 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat

#endif

#endif
#endif