Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mldsa/native/aarch64/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#define MLD_USE_NATIVE_REJ_UNIFORM
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA2
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA4
#define MLD_USE_NATIVE_POLY_DECOMPOSE_32
#define MLD_USE_NATIVE_POLY_DECOMPOSE_88

/* Identifier for this backend so that source and assembly files
* in the build can be appropriately guarded. */
Expand Down Expand Up @@ -93,6 +95,18 @@ static MLD_INLINE int mld_rej_uniform_eta4_native(int32_t *r, unsigned len,
return outlen;
}

/* AArch64 backend for mld_poly_decompose_32_native (declared and specified
 * in mldsa/native/api.h).
 *
 * Pure pass-through: a1, a0 and a are forwarded unchanged, in the same
 * order, to the assembly routine mld_poly_decompose_32_asm. */
static MLD_INLINE void mld_poly_decompose_32_native(int32_t *a1, int32_t *a0,
const int32_t *a)
{
mld_poly_decompose_32_asm(a1, a0, a);
}

/* AArch64 backend for mld_poly_decompose_88_native (declared and specified
 * in mldsa/native/api.h).
 *
 * Pure pass-through: a1, a0 and a are forwarded unchanged, in the same
 * order, to the assembly routine mld_poly_decompose_88_asm. */
static MLD_INLINE void mld_poly_decompose_88_native(int32_t *a1, int32_t *a0,
const int32_t *a)
{
mld_poly_decompose_88_asm(a1, a0, a);
}

#endif /* !__ASSEMBLER__ */

#endif /* !MLD_NATIVE_AARCH64_META_H */
6 changes: 6 additions & 0 deletions mldsa/native/aarch64/src/arith_native_aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,10 @@ unsigned mld_rej_uniform_eta2_asm(int32_t *r, const uint8_t *buf,
unsigned mld_rej_uniform_eta4_asm(int32_t *r, const uint8_t *buf,
unsigned buflen, const uint8_t *table);

/* AArch64 assembly routine (poly_decompose_32_asm.S) for the decomposition
 * with 2*GAMMA2 = 523776. It reads 256 int32_t coefficients from a and
 * writes 256 high parts to a1 and 256 low parts to a0. */
#define mld_poly_decompose_32_asm MLD_NAMESPACE(poly_decompose_32_asm)
void mld_poly_decompose_32_asm(int32_t *a1, int32_t *a0, const int32_t *a);

/* AArch64 assembly routine (poly_decompose_88_asm.S) for the decomposition
 * with 2*GAMMA2 = 190464. It reads 256 int32_t coefficients from a and
 * writes 256 high parts to a1 and 256 low parts to a0. */
#define mld_poly_decompose_88_asm MLD_NAMESPACE(poly_decompose_88_asm)
void mld_poly_decompose_88_asm(int32_t *a1, int32_t *a0, const int32_t *a);

#endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */
114 changes: 114 additions & 0 deletions mldsa/native/aarch64/src/poly_decompose_32_asm.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/
#include "../../../common.h"

#if defined(MLD_ARITH_BACKEND_AARCH64) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)

// decompose32: split the four 32-bit coefficients held in one vector register
// into high parts (a1) and low parts (a0) for 2*GAMMA2 = 523776.
//   a1    - output vector register, receives the high parts
//   a0    - output vector register, receives the low parts
//   input - vector register holding the coefficients; CLOBBERED (the mls in
//           step 4 overwrites it with the unreduced low part)
//   temp  - scratch vector register; clobbered
// Reads the constant registers offset_127, barrett_const, mask_15, gamma2_2x,
// q_half and q, which must be initialised before the macro is expanded.
.macro decompose32 a1, a0, input, temp
// Step 1: Compute ceil(a / 128) as floor((a + 127) / 128); the shift is
// unsigned. This is the first half of the division by
// 2*GAMMA2 = 523776 = 128 * 4092: ceil(a / 128) here, followed by a
// rounded division of that value ("temp" below) by 4092 in step 2.
add \a1\().4s, \input\().4s, offset_127.4s
ushr \a1\().4s, \a1\().4s, #7

// Step 2: Barrett-style rounded division: round(temp * 1025 / 2^22),
// i.e. approximately round(ceil(a/128) / 4092).
// Because this step rounds to nearest, the combined a1 approximates
// a / 523776 rounded to nearest, NOT floor(a / 523776); that is what lets
// step 4 produce a low part centred around 0, as the API contract
// (-GAMMA2 < a0 <= GAMMA2) requires.
// NOTE(review): exact agreement with the reference decompose on every
// input in [0, Q-1] is not shown here; confirm against tests/proofs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the second equality hold?

// sqrdmulh computes: (2 * temp * 524800 + 2^31) >> 32
// which is equivalent to: (temp * 1025 + 2^21) >> 22,
// because 524800 = 1025 * 2^9, so both terms carry a common factor 2^10.
sqrdmulh \a1\().4s, \a1\().4s, barrett_const.4s

// Step 3: Wrap the top value around to 0. For inputs in [0, Q-1] step 2
// yields a1 in [0, 16], where 16 = (Q-1)/(2*GAMMA2); masking with 15 maps
// 16 to 0 and leaves 0..15 unchanged, so a1 ends up in [0, 15].
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment would suggest that 15 is excluded after the masking, but it isn't, seeing that mask_15 is elementwise0x0F?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the comment is wrong. the valid range is [0, 15]. will fix.

and \a1\().16b, \a1\().16b, mask_15.16b

// Step 4: Compute a0 = a - a1 * 2*GAMMA2 (low part of decomposition).
// The result is written back into the input register.
mls \input\().4s, \a1\().4s, gamma2_2x.4s

// Step 5: Conditional reduction: if a0 > (Q-1)/2 then a0 -= Q.
// cmgt sets a lane to all-ones where the (signed) comparison holds, so the
// and selects Q in exactly those lanes and 0 elsewhere. This only triggers
// for the lanes whose a1 was wrapped from 16 to 0 in step 3.
cmgt \temp\().4s, \input\().4s, q_half.4s
and \temp\().16b, \temp\().16b, q.16b
sub \a0\().4s, \input\().4s, \temp\().4s
.endm

/* Parameters */
a1_ptr .req x0 // Output polynomial with coefficients c1
a0_ptr .req x1 // Output polynomial with coefficients c0
a_ptr .req x2 // Input polynomial

// Loop counter (number of 64-byte chunks still to process)
count .req x3

/* Constant register assignments */
q .req v20 // Q = 8380417
q_half .req v21 // (Q-1)/2
gamma2_2x .req v22 // 2*GAMMA2 = 523776
mask_15 .req v23 // mask = 15
offset_127 .req v24 // offset = 127
barrett_const .req v25 // Barrett constant = 524800


// poly_decompose_32_asm(a1 = x0, a0 = x1, a = x2)
//
// Applies the decompose32 macro to 256 int32 coefficients: 16 loop
// iterations, each consuming 64 bytes (four 4-lane vectors) from a and
// producing 64 bytes for each of a1 and a0. All three pointers are advanced
// by 64 bytes per iteration.
//
// Registers written: w4, w5, w7, w11, x3 (count), v0-v7, v16-v26, plus the
// three pointer arguments x0-x2 and the condition flags.
.text
.global MLD_ASM_NAMESPACE(poly_decompose_32_asm)
.balign 4
MLD_ASM_FN_SYMBOL(poly_decompose_32_asm)
// Load constants into SIMD registers
// Q = 0x007FE001 = 8380417, assembled from its two 16-bit halves
movz w4, #57345
movk w4, #127, lsl #16
dup q.4s, w4

// Q >> 1 = 4190208 = (Q-1)/2
lsr w5, w4, #1
dup q_half.4s, w5

// 2*GAMMA2 = 0x0007FE00 = 523776
movz w7, #0xfe00
movk w7, #7, lsl #16
dup gamma2_2x.4s, w7

movi mask_15.4s, #15
movi offset_127.4s, #127

// Barrett constant = 0x00080200 = 524800 = 1025 * 2^9
movz w11, #0x0200
movk w11, #8, lsl #16
dup barrett_const.4s, w11

// 64 vectors of 4 coefficients in total, 4 vectors per iteration
mov count, #(64/4)

poly_decompose_32_loop:
// Load vectors 1..3 at fixed offsets first; the final load fetches vector 0
// and post-increments a_ptr by 64 bytes for the next iteration.
ldr q1, [a_ptr, #1*16]
ldr q2, [a_ptr, #2*16]
ldr q3, [a_ptr, #3*16]
ldr q0, [a_ptr], #4*16

// Arguments: high-part output, low-part output, input (clobbered), scratch.
// v26 is shared as scratch by all four expansions.
decompose32 v4, v5, v1, v26
decompose32 v6, v7, v2, v26
decompose32 v16, v17, v3, v26
decompose32 v18, v19, v0, v26


// Store in the same pattern as the loads: offsets 16/32/48 first, then the
// offset-0 vector with a 64-byte post-increment of the output pointer.
str q4, [a1_ptr, #1*16]
str q6, [a1_ptr, #2*16]
str q16, [a1_ptr, #3*16]
str q18, [a1_ptr], #4*16
str q5, [a0_ptr, #1*16]
str q7, [a0_ptr, #2*16]
str q17, [a0_ptr, #3*16]
str q19, [a0_ptr], #4*16

subs count, count, #1
bne poly_decompose_32_loop

ret

// Release the register aliases so they do not leak into other sources
.unreq a1_ptr
.unreq a0_ptr
.unreq a_ptr
.unreq count
.unreq q
.unreq q_half
.unreq gamma2_2x
.unreq mask_15
.unreq offset_127
.unreq barrett_const

#endif /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
112 changes: 112 additions & 0 deletions mldsa/native/aarch64/src/poly_decompose_88_asm.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/
#include "../../../common.h"

#if defined(MLD_ARITH_BACKEND_AARCH64) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)

// decompose88: split the four 32-bit coefficients held in one vector register
// into high parts (a1) and low parts (a0) for 2*GAMMA2 = 190464.
//   a1    - output vector register, receives the high parts
//   a0    - output vector register, receives the low parts
//   input - vector register holding the coefficients; CLOBBERED (the mls in
//           step 4 overwrites it with the unreduced low part)
//   temp  - scratch vector register; clobbered (used in steps 3 and 5)
// Reads the constant registers offset_127, barrett_const, const44, gamma2_2x,
// q_half and q, which must be initialised before the macro is expanded.
.macro decompose88 a1, a0, input, temp
// Step 1: Compute ceil(a / 128) as floor((a + 127) / 128); the shift is
// unsigned. This is the first half of the division by
// 2*GAMMA2 = 190464 = 128 * 1488: ceil(a / 128) here, followed by a
// rounded division of that value ("temp" below) by 1488 in step 2.
add \a1\().4s, \input\().4s, offset_127.4s
ushr \a1\().4s, \a1\().4s, #7

// Step 2: Barrett-style rounded division: round(temp * 11275 / 2^24),
// i.e. approximately round(ceil(a/128) / 1488).
// Because this step rounds to nearest, the combined a1 approximates
// a / 190464 rounded to nearest, NOT floor(a / 190464); that is what lets
// step 4 produce a low part centred around 0, as the API contract
// (-GAMMA2 < a0 <= GAMMA2) requires.
// sqrdmulh computes: (2 * temp * 1443201 + 2^31) >> 32
// which is intended to match: (temp * 11275 + 2^23) >> 24.
// NOTE(review): 1443201 = 11275 * 2^7 + 1, so this is not a literal
// identity; it holds only if the extra "+1" (which adds 2 * temp before
// the shift) never changes the quotient for temp <= 65472. Confirm.
sqrdmulh \a1\().4s, \a1\().4s, barrett_const.4s

// Step 3: Mask to valid range [0, 43] since (Q-1)/(2*GAMMA2) = 44.
// cmhi is an unsigned "44 > a1" compare giving all-ones lanes for a1 < 44
// and zero otherwise, so the and keeps 0..43 and turns a1 = 44 into 0.
cmhi \temp\().4s, const44.4s, \a1\().4s
and \a1\().16b, \a1\().16b, \temp\().16b

// Step 4: Compute a0 = a - a1 * 2*GAMMA2 (low part of decomposition).
// The result is written back into the input register.
mls \input\().4s, \a1\().4s, gamma2_2x.4s

// Step 5: Conditional reduction: if a0 > (Q-1)/2 then a0 -= Q.
// cmgt sets a lane to all-ones where the (signed) comparison holds, so the
// and selects Q in exactly those lanes and 0 elsewhere.
cmgt \temp\().4s, \input\().4s, q_half.4s
and \temp\().16b, \temp\().16b, q.16b
sub \a0\().4s, \input\().4s, \temp\().4s
.endm

/* Parameters */
a1_ptr .req x0 // Output polynomial with coefficients c1
a0_ptr .req x1 // Output polynomial with coefficients c0
a_ptr .req x2 // Input polynomial

// Loop counter (number of 64-byte chunks still to process)
count .req x3

/* Constant register assignments */
q .req v20 // Q = 8380417
q_half .req v21 // (Q-1)/2
gamma2_2x .req v22 // 2*GAMMA2 = 190464
const44 .req v23 // const = 44
offset_127 .req v24 // offset = 127
barrett_const .req v25 // Barrett constant = 1443201

// poly_decompose_88_asm(a1 = x0, a0 = x1, a = x2)
//
// Applies the decompose88 macro to 256 int32 coefficients: 16 loop
// iterations, each consuming 64 bytes (four 4-lane vectors) from a and
// producing 64 bytes for each of a1 and a0. All three pointers are advanced
// by 64 bytes per iteration.
//
// Registers written: w4, w5, w7, w11, x3 (count), v0-v7, v16-v26, plus the
// three pointer arguments x0-x2 and the condition flags.
.text
.global MLD_ASM_NAMESPACE(poly_decompose_88_asm)
.balign 4
MLD_ASM_FN_SYMBOL(poly_decompose_88_asm)
// Load constants into SIMD registers
// Q = 0x007FE001 = 8380417, assembled from its two 16-bit halves
movz w4, #57345
movk w4, #127, lsl #16
dup q.4s, w4

// Q >> 1 = 4190208 = (Q-1)/2
lsr w5, w4, #1
dup q_half.4s, w5

// 2*GAMMA2 = 0x0002E800 = 190464
movz w7, #0xe800
movk w7, #0x2, lsl #16
dup gamma2_2x.4s, w7

movi const44.4s, #44
movi offset_127.4s, #127

// Barrett constant = 0x00160581 = 1443201
movz w11, #0x0581
movk w11, #0x16, lsl #16
dup barrett_const.4s, w11

// 64 vectors of 4 coefficients in total, 4 vectors per iteration
mov count, #(64/4)
poly_decompose_88_loop:
// Load vectors 1..3 at fixed offsets first; the final load fetches vector 0
// and post-increments a_ptr by 64 bytes for the next iteration.
ldr q1, [a_ptr, #1*16]
ldr q2, [a_ptr, #2*16]
ldr q3, [a_ptr, #3*16]
ldr q0, [a_ptr], #4*16

// Arguments: high-part output, low-part output, input (clobbered), scratch.
// v26 is shared as scratch by all four expansions.
decompose88 v4, v5, v1, v26
decompose88 v6, v7, v2, v26
decompose88 v16, v17, v3, v26
decompose88 v18, v19, v0, v26

// Store in the same pattern as the loads: offsets 16/32/48 first, then the
// offset-0 vector with a 64-byte post-increment of the output pointer.
str q4, [a1_ptr, #1*16]
str q6, [a1_ptr, #2*16]
str q16, [a1_ptr, #3*16]
str q18, [a1_ptr], #4*16
str q5, [a0_ptr, #1*16]
str q7, [a0_ptr, #2*16]
str q17, [a0_ptr, #3*16]
str q19, [a0_ptr], #4*16

subs count, count, #1
bne poly_decompose_88_loop

ret

// Release the register aliases so they do not leak into other sources
.unreq a1_ptr
.unreq a0_ptr
.unreq a_ptr
.unreq count
.unreq q
.unreq q_half
.unreq gamma2_2x
.unreq const44
.unreq offset_127
.unreq barrett_const

#endif /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
43 changes: 43 additions & 0 deletions mldsa/native/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,47 @@ static MLD_INLINE int mld_rej_uniform_eta4_native(int32_t *r, unsigned len,
unsigned buflen);
#endif /* MLD_USE_NATIVE_REJ_UNIFORM_ETA4 */

#if defined(MLD_USE_NATIVE_POLY_DECOMPOSE_32)
/*************************************************
* Name: mld_poly_decompose_32_native
*
* Description: Native implementation of poly_decompose for GAMMA2 = (Q-1)/32.
* For all coefficients c of the input polynomial,
* compute high and low bits c1, c0 such that
* c mod MLDSA_Q = c1*(2*GAMMA2) + c0
* with -(2*GAMMA2)/2 < c0 <= (2*GAMMA2)/2, except
* if c1 = (MLDSA_Q-1)/(2*GAMMA2), where we set
* c1 = 0 and -(2*GAMMA2)/2 <= c0 = c mod MLDSA_Q - MLDSA_Q < 0.
* Here 2*GAMMA2 = 523776 and (MLDSA_Q-1)/(2*GAMMA2) = 16,
* so the returned c1 lies in [0, 15].
* Assumes coefficients to be standard representatives.
*
* Arguments: - int32_t *a1: output polynomial with coefficients c1
* - int32_t *a0: output polynomial with coefficients c0
* - const int32_t *a: input polynomial
**************************************************/
static MLD_INLINE void mld_poly_decompose_32_native(int32_t *a1, int32_t *a0,
const int32_t *a);
#endif /* MLD_USE_NATIVE_POLY_DECOMPOSE_32 */

#if defined(MLD_USE_NATIVE_POLY_DECOMPOSE_88)
/*************************************************
* Name: mld_poly_decompose_88_native
*
* Description: Native implementation of poly_decompose for GAMMA2 = (Q-1)/88.
* For all coefficients c of the input polynomial,
* compute high and low bits c1, c0 such that
* c mod MLDSA_Q = c1*(2*GAMMA2) + c0
* with -(2*GAMMA2)/2 < c0 <= (2*GAMMA2)/2, except
* if c1 = (MLDSA_Q-1)/(2*GAMMA2), where we set
* c1 = 0 and -(2*GAMMA2)/2 <= c0 = c mod MLDSA_Q - MLDSA_Q < 0.
* Here 2*GAMMA2 = 190464 and (MLDSA_Q-1)/(2*GAMMA2) = 44,
* so the returned c1 lies in [0, 43].
* Assumes coefficients to be standard representatives.
*
* Arguments: - int32_t *a1: output polynomial with coefficients c1
* - int32_t *a0: output polynomial with coefficients c0
* - const int32_t *a: input polynomial
**************************************************/
static MLD_INLINE void mld_poly_decompose_88_native(int32_t *a1, int32_t *a0,
const int32_t *a);
#endif /* MLD_USE_NATIVE_POLY_DECOMPOSE_88 */


#endif /* !MLD_NATIVE_API_H */
14 changes: 14 additions & 0 deletions mldsa/native/x86_64/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#define MLD_USE_NATIVE_REJ_UNIFORM
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA2
#define MLD_USE_NATIVE_REJ_UNIFORM_ETA4
#define MLD_USE_NATIVE_POLY_DECOMPOSE_32
#define MLD_USE_NATIVE_POLY_DECOMPOSE_88

#if !defined(__ASSEMBLER__)
#include <string.h>
Expand Down Expand Up @@ -98,6 +100,18 @@ static MLD_INLINE int mld_rej_uniform_eta4_native(int32_t *r, unsigned len,
return outlen;
}

/* x86_64 backend for mld_poly_decompose_32_native (declared and specified
 * in mldsa/native/api.h).
 *
 * Reinterprets the three int32_t arrays as arrays of __m256i and forwards
 * them, in the same order, to mld_poly_decompose_32_avx2.
 *
 * NOTE(review): the pointer casts presume the buffers meet whatever
 * alignment the AVX2 routine relies on; confirm against the callers. */
static MLD_INLINE void mld_poly_decompose_32_native(int32_t *a1, int32_t *a0,
const int32_t *a)
{
mld_poly_decompose_32_avx2((__m256i *)a1, (__m256i *)a0, (const __m256i *)a);
}

/* x86_64 backend for mld_poly_decompose_88_native (declared and specified
 * in mldsa/native/api.h).
 *
 * Reinterprets the three int32_t arrays as arrays of __m256i and forwards
 * them, in the same order, to mld_poly_decompose_88_avx2.
 *
 * NOTE(review): the pointer casts presume the buffers meet whatever
 * alignment the AVX2 routine relies on; confirm against the callers. */
static MLD_INLINE void mld_poly_decompose_88_native(int32_t *a1, int32_t *a0,
const int32_t *a)
{
mld_poly_decompose_88_avx2((__m256i *)a1, (__m256i *)a0, (const __m256i *)a);
}

#endif /* !__ASSEMBLER__ */

#endif /* !MLD_NATIVE_X86_64_META_H */
6 changes: 6 additions & 0 deletions mldsa/native/x86_64/src/arith_native_x86_64.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,10 @@ unsigned mld_rej_uniform_eta2_avx2(
unsigned mld_rej_uniform_eta4_avx2(
int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN]);

/* Routine called by mld_poly_decompose_32_native on x86_64: a1 and a0 are
 * the output polynomials, a is the input polynomial, all viewed as arrays
 * of __m256i.
 *
 * NOTE(review): the namespaced symbol keeps the "mld_" prefix inside
 * MLD_NAMESPACE(); confirm this matches the name the implementation
 * actually defines. */
#define mld_poly_decompose_32_avx2 MLD_NAMESPACE(mld_poly_decompose_32_avx2)
void mld_poly_decompose_32_avx2(__m256i *a1, __m256i *a0, const __m256i *a);

/* Routine called by mld_poly_decompose_88_native on x86_64: a1 and a0 are
 * the output polynomials, a is the input polynomial, all viewed as arrays
 * of __m256i.
 *
 * NOTE(review): the namespaced symbol keeps the "mld_" prefix inside
 * MLD_NAMESPACE(); confirm this matches the name the implementation
 * actually defines. */
#define mld_poly_decompose_88_avx2 MLD_NAMESPACE(mld_poly_decompose_88_avx2)
void mld_poly_decompose_88_avx2(__m256i *a1, __m256i *a0, const __m256i *a);

#endif /* !MLD_NATIVE_X86_64_SRC_ARITH_NATIVE_X86_64_H */
Loading
Loading