Skip to content

Commit 7d30c42

Browse files
committed
ggml-cpu : add RVV support for f32 softmax
1 parent 2231433 commit 7d30c42

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

ggml/src/ggml-cpu/vec.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
335335
vst1q_f32(y + i, val);
336336
sum += (ggml_float)vaddvq_f32(val);
337337
}
338+
#elif defined(__riscv_v_intrinsic)
339+
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
340+
for (int avl; i < n; i += avl) {
341+
avl = __riscv_vsetvl_e32m2(n - i);
342+
vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
343+
__riscv_vse32_v_f32m2(&y[i], val, avl);
344+
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
345+
}
346+
return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
338347
#endif
339348
for (; i < n; ++i) {
340349
float val = expf(x[i] - max);

ggml/src/ggml-cpu/vec.h

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,51 @@ inline static __m128 ggml_v_silu(__m128 x) {
986986
return _mm_div_ps(x, one_plus_exp_neg_x);
987987
}
988988

989-
#endif // __ARM_NEON / __AVX2__ / __SSE2__
989+
#elif defined(__riscv_v_intrinsic)
990+
991+
// Element-wise single-precision exp(x) for an RVV float32 vector at LMUL=2.
// Adapted from the Arm Limited optimized routines.
992+
// the maximum error is 1.45358 plus 0.5 ulps
993+
// numbers above 88.38 will flush to infinity
994+
// numbers beneath -103.97 will flush to zero
//
// x  : input lanes
// vl : element count forwarded unchanged to every intrinsic below
//      (presumably a value obtained from __riscv_vsetvl_e32m2 - confirm at call sites)
// returns a vector holding the exponential of each of the first vl lanes of x
995+
inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
996+
// r = 0x1.8p23f (1.5 * 2^23). Adding it to a float of moderate magnitude leaves the
// rounded integer part of that float in the low mantissa bits of the sum.
const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl);
997+
#ifdef __riscv_xtheadvector
998+
// workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
999+
// Copy r via "r + 0.0f" instead of a register move, then accumulate into the copy.
vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl);
1000+
z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl);
1001+
#else
1002+
// z = r + x * 0x1.715476p+0f, where the constant is ~1.442695 (log2(e)).
const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl);
1003+
#endif
1004+
// n = z - r: x * log2(e) as an integer-valued float (the rounding happened in the add above).
const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl);
1005+
// b = x - n * 0x1.62e4p-1f - n * 0x1.7f7d1cp-20f. The two constants sum to ~ln(2)
// (a high/low split), so b is the reduced argument x - n*ln(2).
const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl),
1006+
0x1.7f7d1cp-20f, n, vl);
1007+
// e = raw bits of z shifted left by 23, i.e. the low bits of z moved up to the
// float32 exponent-field position.
const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl);
1008+
// k = float whose bit pattern is e + bits(1.0f); this is the 2^n scale factor as long as
// n stays inside the normal exponent range (out-of-range lanes are handled below).
const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f
1009+
// c = mask of lanes with |n| > 126, i.e. lanes that need the slow path.
const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl);
1010+
// u = b^2, shared by the polynomial terms below.
const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl);
1011+
// j = c0*b + ((c1 + c2*b) + (c3 + c4*b)*u)*u with c0..c4 close to 1, 1/2, 1/6, 1/24, 1/120;
// presumably a fitted approximation of exp(b) - 1 on the reduced range.
const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2(
1012+
__riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),
1013+
__riscv_vfmacc_vv_f32m2(
1014+
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),
1015+
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),
1016+
u, vl), u, vl);
1017+
// Fast path: no lane has |n| > 126, so return k + j*k for every lane.
if (!__riscv_vcpop_m_b16(c, vl))
1018+
return __riscv_vfmacc_vv_f32m2(k, j, k, vl);
1019+
// Slow path. dm marks lanes with n <= 0.
const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl);
1020+
// d = 0x82000000 on dm lanes, 0 elsewhere: an exponent-field offset used to split the scale.
const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl);
1021+
// s1 has bits d + 0x7f000000 (32-bit wrap-around): 0x7f000000 (2^127) where n > 0,
// 0x01000000 (2^-125) where n <= 0.
const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl));
1022+
// s2 has bits e - d: the remaining part of the scale, applied before s1.
const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl));
1023+
// r1: lanes flagged in c take (s2 + s2*j) * s1 (scale applied in two steps);
// all other lanes take the regular k + k*j.
const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2(
1024+
__riscv_vfmacc_vv_f32m2(k, k, j, vl),
1025+
__riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),
1026+
c, vl);
1027+
// Lanes with |n| > 192 are replaced by s1*s1, which overflows to +inf for n > 0 and
// underflows to 0 for n <= 0; every other lane keeps r1.
return __riscv_vmerge_vvm_f32m2(
1028+
r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),
1029+
__riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),
1030+
vl);
1031+
}
1032+
1033+
#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
9901034

9911035
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
9921036
for (int i = 0; i < n; ++i) {

0 commit comments

Comments
 (0)