Skip to content

Commit c69e44a

Browse files
committed
Implement accelerated computation of (x << e) % y in unsigned integers
1 parent 1dd087a commit c69e44a

File tree

3 files changed

+254
-1
lines changed

3 files changed

+254
-1
lines changed

libm/src/math/support/int_traits/narrowing_div.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};
77
/// This is the inverse of widening multiplication:
88
/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
99
/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
10-
#[allow(dead_code)]
1110
pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
1211
/// Computes `(self / n, self % n)`
1312
///

libm/src/math/support/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub(crate) mod feature_detect;
88
mod float_traits;
99
pub mod hex_float;
1010
mod int_traits;
11+
mod modular;
1112

1213
#[allow(unused_imports)]
1314
pub use big::{i256, u256};
@@ -30,6 +31,8 @@ pub use hex_float::hf128;
3031
pub use hex_float::{hf32, hf64};
3132
#[allow(unused_imports)]
3233
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};
34+
#[allow(unused_imports)]
35+
pub use modular::linear_mul_reduction;
3336

3437
/// Hint to the compiler that the current path is cold.
3538
pub fn cold_path() {

libm/src/math/support/modular.rs

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
/* SPDX-License-Identifier: MIT OR Apache-2.0 */
2+
3+
//! To keep the equations somewhat concise, the following conventions are used:
4+
//! - all integer operations are in the mathematical sense, without overflow
5+
//! - concatenation means multiplication: `2xq = 2 * x * q`
6+
//! - `R = (1 << U::BITS)` is the modulus of wrapping arithmetic in `U`
7+
8+
use crate::support::int_traits::NarrowingDiv;
9+
use crate::support::{DInt, HInt, Int};
10+
11+
/// Compute the remainder `(x << e) % y` with unbounded integers.
12+
/// Requires `x < 2y` and `y.leading_zeros() >= 2`
13+
#[allow(dead_code)]
14+
pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U
15+
where
16+
U: HInt + Int<Unsigned = U>,
17+
U::D: NarrowingDiv,
18+
{
19+
assert!(y <= U::MAX >> 2);
20+
assert!(x < (y << 1));
21+
let _0 = U::ZERO;
22+
let _1 = U::ONE;
23+
24+
// power of two divisors
25+
if (y & (y - _1)).is_zero() {
26+
if e < U::BITS {
27+
// shift and only keep low bits
28+
return (x << e) & (y - _1);
29+
} else {
30+
// would shift out all the bits
31+
return _0;
32+
}
33+
}
34+
35+
// Use the identity `(x << e) % y == ((x << (e + s)) % (y << s)) >> s`
36+
// to shift the divisor so it has exactly two leading zeros to satisfy
37+
// the precondition of `Reducer::new`
38+
let s = y.leading_zeros() - 2;
39+
e += s;
40+
y <<= s;
41+
42+
// `m: Reducer` keeps track of the remainder `x` in a form that makes it
43+
// very efficient to do `x <<= k` modulo `y` for integers `k < U::BITS`
44+
let mut m = Reducer::new(x, y);
45+
46+
// Use the faster special case with constant `k == U::BITS - 1` while we can
47+
while e >= U::BITS - 1 {
48+
m.word_reduce();
49+
e -= U::BITS - 1;
50+
}
51+
// Finish with the variable shift operation
52+
m.shift_reduce(e);
53+
54+
// The partial remainder is in `[0, 2y)` ...
55+
let r = m.partial_remainder();
56+
// ... so check and correct, and compensate for the earlier shift.
57+
r.checked_sub(y).unwrap_or(r) >> s
58+
}
59+
60+
/// Helper type for computing the reductions. The implementation has a number
61+
/// of seemingly weird choices, but everything is aimed at streamlining
62+
/// `Reducer::word_reduce` into its current form.
63+
///
64+
/// Implicitly contains:
65+
/// n in (R/8, R/4)
66+
/// x in [0, 2n)
67+
/// The value of `n` is fixed for a given `Reducer`,
68+
/// but the value of `x` is modified by the methods.
69+
#[derive(Debug, Clone, PartialEq, Eq)]
70+
struct Reducer<U: HInt> {
71+
// m = 2n
72+
m: U,
73+
// q = (RR/2) / m
74+
// r = (RR/2) % m
75+
// Then RR/2 = qm + r, where `0 <= r < m`
76+
// The value `q` is only needed during construction, so isn't saved.
77+
r: U,
78+
// The value `x` is implicitly stored as `2 * q * x`:
79+
_2xq: U::D,
80+
}
81+
82+
impl<U> Reducer<U>
83+
where
84+
U: HInt,
85+
U: Int<Unsigned = U>,
86+
{
87+
/// Construct a reducer for `(x << _) mod n`.
88+
///
89+
/// Requires `R/8 < n < R/4` and `x < 2n`.
90+
fn new(x: U, n: U) -> Self
91+
where
92+
U::D: NarrowingDiv,
93+
{
94+
let _1 = U::ONE;
95+
assert!(n > (_1 << (U::BITS - 3)));
96+
assert!(n < (_1 << (U::BITS - 2)));
97+
let m = n << 1;
98+
assert!(x < m);
99+
100+
// We need q and r s.t. RR/2 = qm + r, and `0 <= r < m`
101+
// As R/4 < m < R/2,
102+
// we have R <= q < 2R
103+
// so let q = R + f
104+
// RR/2 = (R + f)m + r
105+
// R(R/2 - m) = fm + r
106+
107+
// v = R/2 - m < R/4 < m
108+
let v = (_1 << (U::BITS - 1)) - m;
109+
let (f, r) = v.widen_hi().checked_narrowing_div_rem(m).unwrap();
110+
111+
// xq < qm <= RR/2
112+
// 2xq < RR
113+
// 2xq = 2xR + 2xf;
114+
let _2x: U = x << 1;
115+
let _2xq = _2x.widen_hi() + _2x.widen_mul(f);
116+
Self { m, r, _2xq }
117+
}
118+
119+
/// Extract the current remainder in the range `[0, 2n)`
120+
fn partial_remainder(&self) -> U {
121+
// RR/2 = qm + r, 0 <= r < m
122+
// 2xq = uR + v, 0 <= v < R
123+
// muR = 2mxq - mv
124+
// = xRR - 2xr - mv
125+
// mu + (2xr + mv)/R == xR
126+
127+
// 0 <= 2xq < RR
128+
// R <= q < 2R
129+
// 0 <= x < R/2
130+
// R/4 < m < R/2
131+
// 0 <= r < m
132+
// 0 <= mv < mR
133+
// 0 <= 2xr < rR < mR
134+
135+
// 0 <= (2xr + mv)/R < 2m
136+
// Add `mu` to each term to obtain:
137+
// mu <= xR < mu + 2m
138+
139+
// Since `0 <= 2m < R`, `xR` is the only multiple of `R` between
140+
// `mu` and `m(u+2)`, so the high half of `m(u+2)` must equal `x`.
141+
let _1 = U::ONE;
142+
self.m.widen_mul(self._2xq.hi() + (_1 + _1)).hi()
143+
}
144+
145+
/// Replace the remainder `x` with `(x << k) - un`,
146+
/// for a suitable quotient `u`, which is returned.
147+
fn shift_reduce(&mut self, k: u32) -> U {
148+
assert!(k < U::BITS);
149+
// 2xq << k = aRR/2 + b;
150+
let a = self._2xq.hi() >> (U::BITS - 1 - k);
151+
let (low, high) = (self._2xq << k).lo_hi();
152+
let b = U::D::from_lo_hi(low, high & (U::MAX >> 1));
153+
154+
// (2xq << k) - aqm
155+
// = aRR/2 + b - aqm
156+
// = a(RR/2 - qm) + b
157+
// = ar + b
158+
self._2xq = a.widen_mul(self.r) + b;
159+
a
160+
}
161+
162+
/// Replace the remainder `x` with `x(R/2) - un`,
163+
/// for a suitable quotient `u`, which is returned.
164+
fn word_reduce(&mut self) -> U {
165+
// 2xq = uR + v
166+
let (v, u) = self._2xq.lo_hi();
167+
// xqR - uqm
168+
// = uRR/2 + vR/2 - uRR/2 + ur
169+
// = ur + (v/2)R
170+
self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1);
171+
u
172+
}
173+
}
174+
175+
#[cfg(test)]
mod test {
    use crate::support::linear_mul_reduction;
    use crate::support::modular::Reducer;

    /// Exercises every `Reducer` operation on `u8` for all valid moduli
    /// `n` and all starting remainders `x < 2n`.
    #[test]
    fn reducer_ops() {
        for n in 33..=63_u8 {
            for x in 0..2 * n {
                let base = Reducer::new(x, n);
                let n = n as u32;
                // A fresh reducer must report the value it was built from.
                let x0 = base.partial_remainder() as u32;
                assert_eq!(x as u32, x0);
                for k in 0..=7 {
                    let mut red = base.clone();
                    let quotient = red.shift_reduce(k) as u32;
                    let x1 = red.partial_remainder() as u32;
                    assert_eq!(x1, (x0 << k) - quotient * n);
                    assert!(x1 < 2 * n);
                    assert!((red._2xq as u32).is_multiple_of(2 * x1));

                    if k != 7 {
                        continue;
                    }
                    // `word_reduce` must agree with
                    // `shift_reduce(U::BITS - 1)`
                    let mut by_word = base.clone();
                    let word_quotient = by_word.word_reduce();
                    assert_eq!(quotient, word_quotient as u32);
                    assert_eq!(by_word, red);
                }
            }
        }
    }

    /// Compares `linear_mul_reduction` on `u8` against a remainder that is
    /// doubled step by step, for every valid `(x, y)` pair.
    #[test]
    fn reduction_u8() {
        for y in 1..64u8 {
            for x in 0..2 * y {
                let mut r = x % y;
                for e in 0..100 {
                    assert_eq!(r, linear_mul_reduction(x, e, y));
                    // advance the expected remainder to `(x << (e + 1)) % y`
                    let doubled = r << 1;
                    r = doubled.checked_sub(y).unwrap_or(doubled);
                }
            }
        }
    }

    /// Spot checks on `u128`, followed by a long run of shifts against a
    /// step-by-step doubled remainder.
    #[test]
    fn reduction_u128() {
        assert_eq!(
            linear_mul_reduction::<u128>(17, 100, 123456789),
            (17 << 100) % 123456789
        );

        // power-of-two divisor
        assert_eq!(
            linear_mul_reduction(0xdead_beef, 100, 1_u128 << 116),
            0xbeef << 100
        );

        let x = 10_u128.pow(37);
        let y = 11_u128.pow(36);
        assert!(x < y);
        let mut r = x;
        for e in 0..1000 {
            assert_eq!(r, linear_mul_reduction(x, e, y));
            // advance the expected remainder to `(x << (e + 1)) % y`
            let doubled = r << 1;
            r = doubled.checked_sub(y).unwrap_or(doubled);
            assert!(r != 0);
        }
    }
}

0 commit comments

Comments
 (0)