Skip to content

Commit ae4922c

Browse files
committed
AArch64: Add native implementation of poly_decompose
This add a native implementation of poly_decompose written from scratch. Resolves #397 Signed-off-by: Matthias J. Kannwischer <[email protected]>
1 parent 0f865ed commit ae4922c

File tree

4 files changed

+246
-0
lines changed

4 files changed

+246
-0
lines changed

mldsa/native/aarch64/meta.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#define MLD_USE_NATIVE_REJ_UNIFORM
1414
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA2
1515
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA4
16+
#define MLD_USE_NATIVE_POLY_DECOMPOSE_32
17+
#define MLD_USE_NATIVE_POLY_DECOMPOSE_88
1618

1719
/* Identifier for this backend so that source and assembly files
1820
* in the build can be appropriately guarded. */
@@ -93,6 +95,18 @@ static MLD_INLINE int mld_rej_uniform_eta4_native(int32_t *r, unsigned len,
9395
return outlen;
9496
}
9597

98+
static MLD_INLINE void mld_poly_decompose_32_native(int32_t *a1, int32_t *a0,
99+
const int32_t *a)
100+
{
101+
mld_poly_decompose_32_asm(a1, a0, a);
102+
}
103+
104+
static MLD_INLINE void mld_poly_decompose_88_native(int32_t *a1, int32_t *a0,
105+
const int32_t *a)
106+
{
107+
mld_poly_decompose_88_asm(a1, a0, a);
108+
}
109+
96110
#endif /* !__ASSEMBLER__ */
97111

98112
#endif /* !MLD_NATIVE_AARCH64_META_H */

mldsa/native/aarch64/src/arith_native_aarch64.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,10 @@ unsigned mld_rej_uniform_eta2_asm(int32_t *r, const uint8_t *buf,
6262
unsigned mld_rej_uniform_eta4_asm(int32_t *r, const uint8_t *buf,
6363
unsigned buflen, const uint8_t *table);
6464

65+
#define mld_poly_decompose_32_asm MLD_NAMESPACE(poly_decompose_32_asm)
66+
void mld_poly_decompose_32_asm(int32_t *a1, int32_t *a0, const int32_t *a);
67+
68+
#define mld_poly_decompose_88_asm MLD_NAMESPACE(poly_decompose_88_asm)
69+
void mld_poly_decompose_88_asm(int32_t *a1, int32_t *a0, const int32_t *a);
70+
6571
#endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
#include "../../../common.h"
6+
7+
#if defined(MLD_ARITH_BACKEND_AARCH64) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
8+
9+
.macro decompose32 a1, a0, input, temp
10+
// Step 1: Compute ceil(a / 128) using floor((a + 127) / 128)
11+
// This is the first part of computing a1 = floor(a / (2*GAMMA2))
12+
// where 2*GAMMA2 = 523776. We break this into two steps:
13+
// ceil(a / 128) followed by round(temp / 4092)
14+
add \a1\().4s, \input\().4s, offset_127.4s
15+
ushr \a1\().4s, \a1\().4s, #7
16+
17+
// Step 2: Barrett reduction with rounding: round(temp * 1025 / 2^22)
18+
// This computes: round(ceil(a/128) / 4092)
19+
// Combined: a1 ≈ round(ceil(a/128) / 4092) ≈ floor(a / 523776)
20+
// sqrdmulh computes: (2 * temp * 524800 + 2^31) >> 32
21+
// which is equivalent to: (temp * 1025 + 2^21) >> 22.
22+
sqrdmulh \a1\().4s, \a1\().4s, barrett_const.4s
23+
24+
// Step 3: Mask to valid range [0, 14] since (Q-1)/(2*GAMMA2) = 15
25+
and \a1\().16b, \a1\().16b, mask_15.16b
26+
27+
// Step 4: Compute a0 = a - a1 * 2*GAMMA2 (low part of decomposition)
28+
mls \input\().4s, \a1\().4s, gamma2_2x.4s
29+
30+
// Step 5: Conditional reduction: if a0 > (Q-1)/2 then a0 -= Q
31+
cmgt \temp\().4s, \input\().4s, q_half.4s
32+
and \temp\().16b, \temp\().16b, q.16b
33+
sub \a0\().4s, \input\().4s, \temp\().4s
34+
.endm
35+
36+
/* Parameters */
37+
a1_ptr .req x0 // Output polynomial with coefficients c1
38+
a0_ptr .req x1 // Output polynomial with coefficients c0
39+
a_ptr .req x2 // Input polynomial
40+
41+
count .req x3
42+
43+
/* Constant register assignments */
44+
q .req v20 // Q = 8380417
45+
q_half .req v21 // (Q-1)/2
46+
gamma2_2x .req v22 // 2*GAMMA2 = 523776
47+
mask_15 .req v23 // mask = 15
48+
offset_127 .req v24 // offset = 127
49+
barrett_const .req v25 // Barrett constant = 524800
50+
51+
52+
.text
53+
.global MLD_ASM_NAMESPACE(poly_decompose_32_asm)
54+
.balign 4
55+
MLD_ASM_FN_SYMBOL(poly_decompose_32_asm)
56+
// Load constants into SIMD registers
57+
movz w4, #57345
58+
movk w4, #127, lsl #16
59+
dup q.4s, w4
60+
61+
lsr w5, w5, #1
62+
dup q_half.4s, w5
63+
64+
movz w7, #0xfe00
65+
movk w7, #7, lsl #16
66+
dup gamma2_2x.4s, w7
67+
68+
movi mask_15.4s, #15
69+
movi offset_127.4s, #127
70+
71+
movz w11, #0x0200
72+
movk w11, #8, lsl #16
73+
dup barrett_const.4s, w11
74+
75+
mov count, #(64/4)
76+
77+
poly_decompose_32_loop:
78+
ldr q1, [a_ptr, #1*16]
79+
ldr q2, [a_ptr, #2*16]
80+
ldr q3, [a_ptr, #3*16]
81+
ldr q0, [a_ptr], #4*16
82+
83+
decompose32 v4, v5, v1, v26
84+
decompose32 v6, v7, v2, v26
85+
decompose32 v16, v17, v3, v26
86+
decompose32 v18, v19, v0, v26
87+
88+
89+
str q4, [a1_ptr, #1*16]
90+
str q6, [a1_ptr, #2*16]
91+
str q16, [a1_ptr, #3*16]
92+
str q18, [a1_ptr], #4*16
93+
str q5, [a0_ptr, #1*16]
94+
str q7, [a0_ptr, #2*16]
95+
str q17, [a0_ptr, #3*16]
96+
str q19, [a0_ptr], #4*16
97+
98+
subs count, count, #1
99+
bne poly_decompose_32_loop
100+
101+
ret
102+
103+
.unreq a1_ptr
104+
.unreq a0_ptr
105+
.unreq a_ptr
106+
.unreq count
107+
.unreq q
108+
.unreq q_half
109+
.unreq gamma2_2x
110+
.unreq mask_15
111+
.unreq offset_127
112+
.unreq barrett_const
113+
114+
#endif /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
#include "../../../common.h"
6+
7+
#if defined(MLD_ARITH_BACKEND_AARCH64) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
8+
9+
.macro decompose88 a1, a0, input, temp
10+
// Step 1: Compute ceil(a / 128) using floor((a + 127) / 128)
11+
// This is the first part of computing a1 = floor(a / (2*GAMMA2))
12+
// where 2*GAMMA2 = 190464. We break this into two steps:
13+
// ceil(a / 128) followed by round(temp / 1488)
14+
add \a1\().4s, \input\().4s, offset_127.4s
15+
ushr \a1\().4s, \a1\().4s, #7
16+
17+
// Step 2: Barrett reduction with rounding: round(temp * 11275 / 2^24)
18+
// This computes: round(ceil(a/128) / 1488)
19+
// Combined: a1 ≈ round(ceil(a/128) / 1488) ≈ floor(a / 190464)
20+
// sqrdmulh computes: (2 * temp * 1443201 + 2^31) >> 32
21+
// which is equivalent to: (temp * 11275 + 2^23) >> 24.
22+
sqrdmulh \a1\().4s, \a1\().4s, barrett_const.4s
23+
24+
// Step 3: Mask to valid range [0, 43] since (Q-1)/(2*GAMMA2) = 44
25+
cmlt \temp\().4s, \a1\().4s, constant_44.4s
26+
and \a1\().16b, \a1\().16b, \temp\().16b
27+
28+
// Step 4: Compute a0 = a - a1 * 2*GAMMA2 (low part of decomposition)
29+
mls \input\().4s, \a1\().4s, gamma2_2x.4s
30+
31+
// Step 5: Conditional reduction: if a0 > (Q-1)/2 then a0 -= Q
32+
cmgt \temp\().4s, \input\().4s, q_half.4s
33+
and \temp\().16b, \temp\().16b, q.16b
34+
sub \a0\().4s, \input\().4s, \temp\().4s
35+
.endm
36+
37+
/* Parameters */
38+
a1_ptr .req x0 // Output polynomial with coefficients c1
39+
a0_ptr .req x1 // Output polynomial with coefficients c0
40+
a_ptr .req x2 // Input polynomial
41+
42+
count .req x3
43+
44+
/* Constant register assignments */
45+
q .req v20 // Q = 8380417
46+
q_half .req v21 // (Q-1)/2
47+
gamma2_2x .req v22 // 2*GAMMA2 = 190464
48+
constant_44 .req v23 // const = 44
49+
offset_127 .req v24 // offset = 127
50+
barrett_const .req v25 // Barrett constant = 1443201
51+
52+
.text
53+
.global MLD_ASM_NAMESPACE(poly_decompose_88_asm)
54+
.balign 4
55+
MLD_ASM_FN_SYMBOL(poly_decompose_88_asm)
56+
// Load constants into SIMD registers
57+
movz w4, #57345
58+
movk w4, #127, lsl #16
59+
dup q.4s, w4
60+
61+
lsr w5, w5, #1
62+
dup q_half.4s, w5
63+
64+
movz w7, #0xe800
65+
movk w7, #0x2, lsl #16
66+
dup gamma2_2x.4s, w7
67+
68+
movi constant_44.4s, #44
69+
movi offset_127.4s, #127
70+
71+
movz w11, #0x0581
72+
movk w11, #0x16, lsl #16
73+
dup barrett_const.4s, w11
74+
75+
mov count, #(64/4)
76+
poly_decompose_88_loop:
77+
ldr q1, [a_ptr, #1*16]
78+
ldr q2, [a_ptr, #2*16]
79+
ldr q3, [a_ptr, #3*16]
80+
ldr q0, [a_ptr], #4*16
81+
82+
decompose88 v4, v5, v1, v26
83+
decompose88 v6, v7, v2, v26
84+
decompose88 v16, v17, v3, v26
85+
decompose88 v18, v19, v0, v2
86+
87+
str q4, [a1_ptr, #1*16]
88+
str q6, [a1_ptr, #2*16]
89+
str q16, [a1_ptr, #3*16]
90+
str q18, [a1_ptr], #4*16
91+
str q5, [a0_ptr, #1*16]
92+
str q7, [a0_ptr, #2*16]
93+
str q17, [a0_ptr, #3*16]
94+
str q19, [a0_ptr], #4*16
95+
96+
subs count, count, #1
97+
bne poly_decompose_88_loop
98+
99+
ret
100+
101+
.unreq a1_ptr
102+
.unreq a0_ptr
103+
.unreq a_ptr
104+
.unreq count
105+
.unreq q
106+
.unreq q_half
107+
.unreq gamma2_2x
108+
.unreq constant_44
109+
.unreq offset_127
110+
.unreq barrett_const
111+
112+
#endif /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED */

0 commit comments

Comments
 (0)