diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index ce0a3e1285eb0..b70302ec8c9b1 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -435,7 +435,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ) if (GGML_RVV) if (GGML_XTHEADVECTOR) - list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d) + list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d) elseif (GGML_RV_ZFH) list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d) else() diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 460367cca09e9..93330b43a9b84 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9072,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32( } sumf = GGML_F32xt_REDUCE_ONE(sum); + #elif defined(__riscv_v_intrinsic) + // todo: RVV implementation + const int np = 0; #else const int np = (nc & ~(GGML_F32_STEP - 1)); @@ -10023,8 +10026,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_stride_2d = head_size * head_size; #if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - // scalar Route to scalar implementation //TODO: Write SVE code + #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic) + // scalar Route to scalar implementation //TODO: Write SVE code and RVV code for (int64_t t = 0; t < T; t++) { int64_t t_offset = t * t_stride; int64_t state_offset = head_size * C * (t / (T / n_seqs)); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index b4ad68c9fd647..f71ce580799f6 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -18,6 +18,10 @@ #include #endif +#if defined(__riscv_v_intrinsic) +#include +#endif + #ifdef __cplusplus extern "C" { #endif @@ -94,24 +98,15 @@ extern "C" { } #elif defined(__riscv) && defined(__riscv_zfhmin) static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) { - float f; - __asm__( - "fmv.h.x %[f], %[h]\n\t" - "fcvt.s.h %[f], %[f]" - : [f] "=&f" (f) - : [h] "r" (h) - ); - return f; + _Float16 hf; + memcpy(&hf, &h, sizeof(ggml_fp16_t)); + return hf; } static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) { ggml_fp16_t res; - __asm__( - "fcvt.h.s %[f], %[f]\n\t" - "fmv.x.h %[h], %[f]" - : [h] "=&r" (res) - : [f] "f" (f) - ); + _Float16 hf = (_Float16)f; + memcpy(&res, &hf, sizeof(ggml_fp16_t)); return res; } @@ -1170,6 +1165,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { #define GGML_F16_VEC_MUL GGML_F32x4_MUL #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +#elif defined(__riscv_v_intrinsic) + +// compatible with vlen >= 128 + +#define GGML_SIMD + +// F32 + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 vfloat32m1_t +#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR) +#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR) +#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR) +#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR) +#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR) +#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR) +#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR) + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + #endif // GGML_F32_ARR / GGML_F16_ARR diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 07b377bdd82a7..d8ec3b81d2446 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -84,6 +84,16 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G } // reduce sum1,sum2 to sum1 GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8); + #elif defined(__riscv_v_intrinsic) + vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1); + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl); + vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl); + } + sumf += __riscv_vfmv_f_s_f32m1_f32(vsum); #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -197,7 +207,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G ggml_float sumf = 0.0; -#if defined(GGML_SIMD) +#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic) const int np = (n & ~(GGML_F16_STEP - 1)); GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; @@ -325,6 +335,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float vst1q_f32(y + i, val); sum += (ggml_float)vaddvq_f32(val); } +#elif defined(__riscv_v_intrinsic) + vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1); + for (int avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m2(n - i); + vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl); + __riscv_vse32_v_f32m2(&y[i], val, avl); + vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl); + } + return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum); #endif for (; i < n; ++i) { float val = expf(x[i] - max); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 2250d93cb00d1..8ccf340d472ad 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -119,6 +119,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG } #if defined(GGML_SIMD) +#if defined(__riscv_v_intrinsic) + // todo: RVV impl + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); + } + } +#else const int np = (n & ~(GGML_F16_STEP - 1)); GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; @@ -149,6 +157,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); } } +#endif #else for (int i = 0; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { @@ -243,6 +252,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const svst1_f32(pg, y + np2, ay1); } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -276,6 +293,13 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { #if defined(GGML_SIMD) +#if defined(__riscv_v_intrinsic) + // todo: RVV impl + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); + } +#else const int np = (n & ~(GGML_F16_STEP - 1)); GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); @@ -297,6 +321,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } +#endif #else // scalar for (int i = 0; i < n; ++i) { @@ -324,6 +349,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int y[i] += x[k][i]*v[k][0]; } } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) { + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl); + ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl); + } + __riscv_vse32_v_f32m8(&y[i], ay, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -375,6 +410,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co for (int i = 0; i < n; ++i) { y[i] = x[i]*s + b; } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl); + vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -436,6 +479,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { ay1 = svmul_f32_m(pg, ay1, vx); svst1_f32(pg, y + np, ay1); } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -467,6 +517,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { #if defined(GGML_SIMD) +#if defined(__riscv_v_intrinsic) + // todo: RVV impl + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); + } +#else const int np = (n & ~(GGML_F16_STEP - 1)); GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); @@ -486,6 +543,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); } +#endif #else // scalar for (int i = 0; i < n; ++i) { @@ -928,7 +986,51 @@ inline static __m128 ggml_v_silu(__m128 x) { return _mm_div_ps(x, one_plus_exp_neg_x); } -#endif // __ARM_NEON / __AVX2__ / __SSE2__ +#elif defined(__riscv_v_intrinsic) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); +#ifdef __riscv_xtheadvector + // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4') + vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl); + z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl); +#else + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); +#endif + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), + 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), + u, vl), u, vl); + if (!__riscv_vcpop_m_b16(c, vl)) + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2( + __riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), + c, vl); + return __riscv_vmerge_vvm_f32m2( + r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), + vl); +} + +#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) {