2 changes: 1 addition & 1 deletion ggml/src/ggml-cpu/CMakeLists.txt
@@ -435,7 +435,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
)
if (GGML_RVV)
if (GGML_XTHEADVECTOR)
list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d)
elseif (GGML_RV_ZFH)
list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
else()
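The _zfhmin added to the XTheadVector march string is what lets the new _Float16-based FP16 conversions in simd-mappings.h (below) lower to the hardware fcvt.s.h / fcvt.h.s instructions rather than library calls. A minimal compile-time guard illustrating that dependency (hypothetical, not part of the PR):

#if defined(__riscv) && !defined(__riscv_zfhmin)
#error "Zfhmin is not enabled for this target; the _Float16 FP16 <-> FP32 fast path will not be used"
#endif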
7 changes: 5 additions & 2 deletions ggml/src/ggml-cpu/ops.cpp
@@ -9072,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32(
}

sumf = GGML_F32xt_REDUCE_ONE(sum);
#elif defined(__riscv_v_intrinsic)
// todo: RVV implementation
const int np = 0;
#else
const int np = (nc & ~(GGML_F32_STEP - 1));

@@ -10023,8 +10026,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_stride_2d = head_size * head_size;

#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
// scalar Route to scalar implementation //TODO: Write SVE code
#if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
// scalar Route to scalar implementation //TODO: Write SVE code and RVV code
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));
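Both hunks route the RVV build to the existing scalar code for now: forcing np = 0 makes the vectorized main loop cover zero elements, so the scalar tail that follows processes the whole row, and the wkv7 kernel reuses the same scalar branch as SVE. A toy sketch of the np = 0 pattern (illustrative only, not ggml code):

// With np forced to 0, the SIMD prefix loop runs zero iterations and the
// scalar tail computes everything, which is the temporary RVV fallback above.
static float sum_f32_fallback(const float * x, int nc) {
    const int np = 0;               // normally nc rounded down to GGML_F32_STEP
    float s = 0.0f;
    for (int i = 0; i < np; ++i) {  // vectorized part: skipped entirely
        s += x[i];
    }
    for (int i = np; i < nc; ++i) { // scalar tail: handles [np, nc)
        s += x[i];
    }
    return s;
}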
53 changes: 39 additions & 14 deletions ggml/src/ggml-cpu/simd-mappings.h
@@ -18,6 +18,10 @@
#include <immintrin.h>
#endif

#if defined(__riscv_v_intrinsic)
#include <riscv_vector.h>
#endif

#ifdef __cplusplus
extern "C" {
#endif
@@ -94,24 +98,15 @@ extern "C" {
}
#elif defined(__riscv) && defined(__riscv_zfhmin)
static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) {
float f;
__asm__(
"fmv.h.x %[f], %[h]\n\t"
"fcvt.s.h %[f], %[f]"
: [f] "=&f" (f)
: [h] "r" (h)
);
return f;
_Float16 hf;
memcpy(&hf, &h, sizeof(ggml_fp16_t));
return hf;
}

static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) {
ggml_fp16_t res;
__asm__(
"fcvt.h.s %[f], %[f]\n\t"
"fmv.x.h %[h], %[f]"
: [h] "=&r" (res)
: [f] "f" (f)
);
_Float16 hf = (_Float16)f;
memcpy(&res, &hf, sizeof(ggml_fp16_t));
return res;
}
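
The inline assembly is replaced by plain _Float16 conversions, which the compiler lowers to the same fcvt.s.h / fcvt.h.s instructions under Zfhmin while staying readable and portable across GCC and Clang. A standalone sketch of the same idea (the uint16_t container, main() and the test value are illustrative assumptions, not ggml code):

#include <stdint.h>
#include <string.h>
#include <stdio.h>

static inline float fp16_to_fp32(uint16_t h) {
    _Float16 hf;
    memcpy(&hf, &h, sizeof(hf));   // reinterpret the 16 bits as an IEEE half
    return (float)hf;              // fcvt.s.h when Zfhmin is enabled
}

static inline uint16_t fp32_to_fp16(float f) {
    _Float16 hf = (_Float16)f;     // fcvt.h.s when Zfhmin is enabled
    uint16_t res;
    memcpy(&res, &hf, sizeof(res));
    return res;
}

int main(void) {
    // round-trips through half precision; prints 3.140625
    printf("%f\n", fp16_to_fp32(fp32_to_fp16(3.14159f)));
    return 0;
}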

@@ -1170,6 +1165,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE

#elif defined(__riscv_v_intrinsic)

// compatible with vlen >= 128

#define GGML_SIMD

// F32

#define GGML_F32_STEP 16
#define GGML_F32_EPR 4

#define GGML_F32x4 vfloat32m1_t
#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR)
#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR)
#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR)
#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR)
#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR)
#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR)
#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR)

#define GGML_F32_VEC GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE

#endif

// GGML_F32_ARR / GGML_F16_ARR
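These mappings pin the active vector length at four 32-bit lanes (hence the "vlen >= 128" comment) so the existing width-agnostic GGML_F32_VEC loops can be reused unchanged; GGML_F32_STEP / GGML_F32_EPR = 4 is the unroll factor (GGML_F32_ARR) those loops use. A rough sketch of what one such loop expands to with these definitions (the reduction step is an assumption, since GGML_F32x4_REDUCE is not shown in this diff):

#include <riscv_vector.h>

// Dot product over fixed 4-lane chunks, mirroring the GGML_F32x4_* mappings:
// LOAD -> vle32, FMA -> vfmacc, vl pinned at GGML_F32_EPR == 4.
static float dot4(const float * x, const float * y, int n) {
    const int epr = 4;                                        // GGML_F32_EPR
    vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, epr);     // GGML_F32x4_ZERO
    int i = 0;
    for (; i + epr <= n; i += epr) {
        vfloat32m1_t ax = __riscv_vle32_v_f32m1(x + i, epr);  // GGML_F32x4_LOAD
        vfloat32m1_t ay = __riscv_vle32_v_f32m1(y + i, epr);
        acc = __riscv_vfmacc_vv_f32m1(acc, ax, ay, epr);      // GGML_F32x4_FMA(acc, ax, ay)
    }
    // horizontal reduction; vfredusum is one way GGML_F32x4_REDUCE could be built
    vfloat32m1_t zero = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    float s = __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m1_f32m1(acc, zero, epr));
    for (; i < n; ++i) {
        s += x[i] * y[i];                                     // scalar tail
    }
    return s;
}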
21 changes: 20 additions & 1 deletion ggml/src/ggml-cpu/vec.cpp
@@ -84,6 +84,16 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
}
// reduce sum1,sum2 to sum1
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
#elif defined(__riscv_v_intrinsic)
vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1);
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl);
vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl);
}
sumf += __riscv_vfmv_f_s_f32m1_f32(vsum);
#else
const int np = (n & ~(GGML_F32_STEP - 1));

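Unlike the fixed four-lane mappings, the hot dot product is strip-mined with vsetvl at LMUL=8: each trip loads as many lanes as the hardware provides, the running partial sum stays in a single m1 register (vsum) that is folded into once per trip, and only one scalar extraction happens after the loop. A self-contained usage sketch of the same pattern with a scalar cross-check (array size and contents are illustrative):

#include <riscv_vector.h>
#include <stdio.h>

static float rvv_dot_f32(int n, const float * x, const float * y) {
    vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1);   // running partial sum
    for (int i = 0, avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e32m8(n - i);                  // lanes handled this trip
        vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
        vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
        vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl);
        vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl);  // fold into vsum
    }
    return __riscv_vfmv_f_s_f32m1_f32(vsum);
}

int main(void) {
    float x[1000], y[1000], ref = 0.0f;
    for (int i = 0; i < 1000; ++i) {
        x[i] = 0.001f * i;
        y[i] = 1.0f - 0.001f * i;
        ref += x[i] * y[i];
    }
    printf("rvv=%f scalar=%f\n", rvv_dot_f32(1000, x, y), ref);
    return 0;
}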
@@ -197,7 +207,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G

ggml_float sumf = 0.0;

#if defined(GGML_SIMD)
#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
const int np = (n & ~(GGML_F16_STEP - 1));

GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
@@ -325,6 +335,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
vst1q_f32(y + i, val);
sum += (ggml_float)vaddvq_f32(val);
}
#elif defined(__riscv_v_intrinsic)
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
for (int avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
__riscv_vse32_v_f32m2(&y[i], val, avl);
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
}
return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
for (; i < n; ++i) {
float val = expf(x[i] - max);
104 changes: 103 additions & 1 deletion ggml/src/ggml-cpu/vec.h
@@ -119,6 +119,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}

#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
// todo: RVV impl
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));

GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
@@ -149,6 +157,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#endif
#else
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
@@ -243,6 +252,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const

svst1_f32(pg, y + np2, ay1);
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl);
__riscv_vse32_v_f32m8(&y[i], ny, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));

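ggml_vec_mad_f32 computes y[i] = x[i]*v + y[i]; note the operand order of vfmadd.vf, whose first operand is the multiplicand (here ax), unlike vfmacc, which accumulates into its first operand. LMUL=8 keeps the number of strip-mine trips low for these simple streaming kernels. A minimal sketch of the same axpy-style update (illustrative only, not ggml code):

#include <riscv_vector.h>

// y[i] = x[i]*v + y[i], strip-mined like ggml_vec_mad_f32 above.
static void rvv_axpy_f32(int n, float v, const float * x, float * y) {
    for (int i = 0, avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e32m8(n - i);
        vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
        vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
        // vfmadd: multiplicand is the first operand, so this is ax*v + ay
        __riscv_vse32_v_f32m8(&y[i], __riscv_vfmadd_vf_f32m8(ax, v, ay, avl), avl);
    }
}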
@@ -276,6 +293,13 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const

inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));

GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@@ -297,6 +321,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@@ -324,6 +349,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
y[i] += x[k][i]*v[k][0];
}
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) {
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl);
ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl);
}
__riscv_vse32_v_f32m8(&y[i], ay, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));

@@ -375,6 +410,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
for (int i = 0; i < n; ++i) {
y[i] = x[i]*s + b;
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);
vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);
__riscv_vse32_v_f32m8(&y[i], ny, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));

@@ -436,6 +479,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
ay1 = svmul_f32_m(pg, ay1, vx);
svst1_f32(pg, y + np, ay1);
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl);
__riscv_vse32_v_f32m8(&y[i], ny, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));

@@ -467,6 +517,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {

inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));

GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@@ -486,6 +543,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@@ -928,7 +986,51 @@ inline static __m128 ggml_v_silu(__m128 x) {
return _mm_div_ps(x, one_plus_exp_neg_x);
}

#endif // __ARM_NEON / __AVX2__ / __SSE2__
#elif defined(__riscv_v_intrinsic)

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl);
#ifdef __riscv_xtheadvector
// workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl);
z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl);
#else
const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl);
#endif
const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl);
const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl),
0x1.7f7d1cp-20f, n, vl);
const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl);
const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f
const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl);
const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl);
const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2(
__riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),
__riscv_vfmacc_vv_f32m2(
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),
u, vl), u, vl);
if (!__riscv_vcpop_m_b16(c, vl))
return __riscv_vfmacc_vv_f32m2(k, j, k, vl);
const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl);
const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl);
const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl));
const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl));
const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2(
__riscv_vfmacc_vv_f32m2(k, k, j, vl),
__riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),
c, vl);
return __riscv_vmerge_vvm_f32m2(
r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),
__riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),
vl);
}

#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
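
The m2 exponential above is the building block of the soft-max path in vec.cpp: each strip-mined chunk is shifted by the row maximum, exponentiated, stored back, and folded into an f64 accumulator with a widening reduction. A usage sketch assuming ggml_v_expf_m2 from above is in scope (buffer handling is illustrative, not ggml code):

// Sum of exp(x[i] - max) over n elements, mirroring ggml_vec_soft_max_f32.
// Assumes <riscv_vector.h> and the ggml_v_expf_m2 helper defined above.
static double rvv_softmax_denominator(int n, float * y, const float * x, float max) {
    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0.0, 1);
    for (int i = 0, avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t val = ggml_v_expf_m2(
            __riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
        __riscv_vse32_v_f32m2(&y[i], val, avl);                   // keep exp(x - max)
        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl); // widening f32 -> f64 sum
    }
    return __riscv_vfmv_f_s_f64m1_f64(vsum);
}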

inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {