@@ -986,7 +986,51 @@ inline static __m128 ggml_v_silu(__m128 x) {
986
986
return _mm_div_ps (x , one_plus_exp_neg_x );
987
987
}
988
988
989
- #endif // __ARM_NEON / __AVX2__ / __SSE2__
989
+ #elif defined(__riscv_v_intrinsic )
990
+
991
+ // adapted from arm limited optimized routine
992
+ // the maximum error is 1.45358 plus 0.5 ulps
993
+ // numbers above 88.38 will flush to infinity
994
+ // numbers beneath -103.97 will flush to zero
995
+ inline static vfloat32m2_t ggml_v_expf_m2 (vfloat32m2_t x , int vl ) {
996
+ const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2 (0x1.8p23f , vl );
997
+ #ifdef __riscv_xtheadvector
998
+ // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
999
+ vfloat32m2_t z = __riscv_vfadd_vf_f32m2 (r , 0.0f , vl );
1000
+ z = __riscv_vfmacc_vf_f32m2 (z , 0x1.715476p+0f , x , vl );
1001
+ #else
1002
+ const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2 (r , 0x1.715476p+0f , x , vl );
1003
+ #endif
1004
+ const vfloat32m2_t n = __riscv_vfsub_vv_f32m2 (z , r , vl );
1005
+ const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2 (__riscv_vfnmsac_vf_f32m2 (x , 0x1.62e4p-1f , n , vl ),
1006
+ 0x1.7f7d1cp-20f , n , vl );
1007
+ const vuint32m2_t e = __riscv_vsll_vx_u32m2 (__riscv_vreinterpret_v_f32m2_u32m2 (z ), 23 , vl );
1008
+ const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2 (__riscv_vadd_vx_u32m2 (e , 0x3f800000 , vl )); // 1.0f
1009
+ const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16 (__riscv_vfabs_v_f32m2 (n , vl ), 126.0f , vl );
1010
+ const vfloat32m2_t u = __riscv_vfmul_vv_f32m2 (b , b , vl );
1011
+ const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2 (
1012
+ __riscv_vfmul_vf_f32m2 (b , 0x1.ffffecp-1f , vl ),
1013
+ __riscv_vfmacc_vv_f32m2 (
1014
+ __riscv_vfmacc_vf_f32m2 (__riscv_vfmv_v_f_f32m2 (0x1.fffdb6p-2f , vl ), 0x1.555e66p-3f , b , vl ),
1015
+ __riscv_vfmacc_vf_f32m2 (__riscv_vfmv_v_f_f32m2 (0x1.573e2ep-5f , vl ), 0x1.0e4020p-7f , b , vl ),
1016
+ u , vl ), u , vl );
1017
+ if (!__riscv_vcpop_m_b16 (c , vl ))
1018
+ return __riscv_vfmacc_vv_f32m2 (k , j , k , vl );
1019
+ const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16 (n , 0.0f , vl );
1020
+ const vuint32m2_t d = __riscv_vmerge_vxm_u32m2 (__riscv_vmv_v_x_u32m2 (0 , vl ), 0x82000000 , dm , vl );
1021
+ const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2 (__riscv_vadd_vx_u32m2 (d , 0x7f000000 , vl ));
1022
+ const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2 (__riscv_vsub_vv_u32m2 (e , d , vl ));
1023
+ const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2 (
1024
+ __riscv_vfmacc_vv_f32m2 (k , k , j , vl ),
1025
+ __riscv_vfmul_vv_f32m2 (__riscv_vfmacc_vv_f32m2 (s2 , s2 , j , vl ), s1 , vl ),
1026
+ c , vl );
1027
+ return __riscv_vmerge_vvm_f32m2 (
1028
+ r1 , __riscv_vfmul_vv_f32m2 (s1 , s1 , vl ),
1029
+ __riscv_vmfgt_vf_f32m2_b16 (__riscv_vfabs_v_f32m2 (n , vl ), 192.0f , vl ),
1030
+ vl );
1031
+ }
1032
+
1033
+ #endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
990
1034
991
1035
inline static void ggml_vec_silu_f16 (const int n , ggml_fp16_t * y , const ggml_fp16_t * x ) {
992
1036
for (int i = 0 ; i < n ; ++ i ) {
0 commit comments