Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions libm/src/math/generic/fmod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
/* SPDX-License-Identifier: MIT OR Apache-2.0 */
use crate::support::{CastFrom, Float, Int, MinInt};
use crate::support::{CastFrom, CastInto, Float, HInt, Int, MinInt, NarrowingDiv};

#[inline]
pub fn fmod<F: Float>(x: F, y: F) -> F {
pub fn fmod<F: Float>(x: F, y: F) -> F
where
F::Int: HInt,
<F::Int as HInt>::D: NarrowingDiv,
{
let _1 = F::Int::ONE;
let sx = x.to_bits() & F::SIGN_MASK;
let ux = x.to_bits() & !F::SIGN_MASK;
Expand All @@ -29,7 +33,7 @@ pub fn fmod<F: Float>(x: F, y: F) -> F {

// To compute `(num << ex) % (div << ey)`, first
// evaluate `rem = (num << (ex - ey)) % div` ...
let rem = reduction(num, ex - ey, div);
let rem = reduction::<F>(num, ex - ey, div);
// ... so the result will be `rem << ey`

if rem.is_zero() {
Expand Down Expand Up @@ -58,11 +62,55 @@ fn into_sig_exp<F: Float>(mut bits: F::Int) -> (F::Int, u32) {
}

/// Compute the remainder `(x * 2.pow(e)) % y` without overflow.
fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I {
x %= y;
for _ in 0..e {
x <<= 1;
x = x.checked_sub(y).unwrap_or(x);
fn reduction<F>(mut x: F::Int, e: u32, y: F::Int) -> F::Int
where
F: Float,
F::Int: HInt,
<<F as Float>::Int as HInt>::D: NarrowingDiv,
{
// `f16` only has 5 exponent bits, so even `f16::MAX = 65504.0` is only
// a 40-bit integer multiple of the smallest subnormal.
if F::BITS == 16 {
debug_assert!(F::EXP_MAX - F::EXP_MIN == 29);
debug_assert!(e <= 29);
let u: u16 = x.cast();
let v: u16 = y.cast();
let u = (u as u64) << e;
let v = v as u64;
return F::Int::cast_from((u % v) as u16);
}
x

// Ensure `x < 2y` for later steps
if x >= (y << 1) {
// This case is only reached with subnormal divisors,
// but it might be better to just normalize all significands
// to make this unnecessary. The further calls could potentially
// benefit from assuming a specific fixed leading bit position.
x %= y;
}

// The simple implementation seems to be fastest for a short reduction
// at this size. The limit here was chosen empirically on an Intel Nehalem.
// Less old CPUs that have faster `u64 * u64 -> u128` might not benefit,
// and 32-bit systems or architectures without hardware multipliers might
// want to do this in more cases.
if F::BITS == 64 && e < 32 {
// Assumes `x < 2y`
for _ in 0..e {
x = x.checked_sub(y).unwrap_or(x);
x <<= 1;
}
return x.checked_sub(y).unwrap_or(x);
}

// Fast path for short reductions
if e < F::BITS {
let w = x.widen() << e;
if let Some((_, r)) = w.checked_narrowing_div_rem(y) {
return r;
}
}

// Assumes `x < 2y`
crate::support::linear_mul_reduction(x, e, y)
}
9 changes: 8 additions & 1 deletion libm/src/math/support/int_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,14 @@ int_impl!(i128, u128);

/// Trait for integers twice the bit width of another integer. This is implemented for all
/// primitives except for `u8`, because there is not a smaller primitive.
pub trait DInt: MinInt {
pub trait DInt:
MinInt
+ ops::Add<Output = Self>
+ ops::Sub<Output = Self>
+ ops::Shl<u32, Output = Self>
+ ops::Shr<u32, Output = Self>
+ Ord
{
/// Integer that is half the bit width of the integer this trait is implemented for
type H: HInt<D = Self>;

Expand Down
1 change: 0 additions & 1 deletion libm/src/math/support/int_traits/narrowing_div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};
/// This is the inverse of widening multiplication:
/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
#[allow(dead_code)]
pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
/// Computes `(self / n, self % n))`
///
Expand Down
3 changes: 2 additions & 1 deletion libm/src/math/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) mod feature_detect;
mod float_traits;
pub mod hex_float;
mod int_traits;
mod modular;

#[allow(unused_imports)]
pub use big::{i256, u256};
Expand All @@ -28,8 +29,8 @@ pub use hex_float::hf16;
pub use hex_float::hf128;
#[allow(unused_imports)]
pub use hex_float::{hf32, hf64};
#[allow(unused_imports)]
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};
pub use modular::linear_mul_reduction;

/// Hint to the compiler that the current path is cold.
pub fn cold_path() {
Expand Down
251 changes: 251 additions & 0 deletions libm/src/math/support/modular.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
/* SPDX-License-Identifier: MIT OR Apache-2.0 */

//! To keep the equations somewhat concise, the following conventions are used:
//! - all integer operations are in the mathematical sense, without overflow
//! - concatenation means multiplication: `2xq = 2 * x * q`
//! - `R = (1 << U::BITS)` is the modulus of wrapping arithmetic in `U`

use crate::support::int_traits::NarrowingDiv;
use crate::support::{DInt, HInt, Int};

/// Compute the remainder `(x << e) % y` with unbounded integers.
/// Requires `x < 2y` and `y.leading_zeros() >= 2`
#[allow(dead_code)]
pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U
where
U: HInt + Int<Unsigned = U>,
U::D: NarrowingDiv,
{
assert!(y <= U::MAX >> 2);
assert!(x < (y << 1));
let _0 = U::ZERO;
let _1 = U::ONE;

// power of two divisors
if (y & (y - _1)).is_zero() {
if e < U::BITS {
// shift and only keep low bits
return (x << e) & (y - _1);
} else {
// would shift out all the bits
return _0;
}
}

// Use the identity `(x << e) % y == ((x << (e + s)) % (y << s)) >> s`
// to shift the divisor so it has exactly two leading zeros to satisfy
// the precondition of `Reducer::new`
let s = y.leading_zeros() - 2;
e += s;
y <<= s;

// `m: Reducer` keeps track of the remainder `x` in a form that makes it
// very efficient to do `x <<= k` modulo `y` for integers `k < U::BITS`
let mut m = Reducer::new(x, y);

// Use the faster special case with constant `k == U::BITS - 1` while we can
while e >= U::BITS - 1 {
m.word_reduce();
e -= U::BITS - 1;
}
// Finish with the variable shift operation
m.shift_reduce(e);

// The partial remainder is in `[0, 2y)` ...
let r = m.partial_remainder();
// ... so check and correct, and compensate for the earlier shift.
r.checked_sub(y).unwrap_or(r) >> s
}

/// Helper type for computing the reductions. The implementation has a number
/// of seemingly weird choices, but everything is aimed at streamlining
/// `Reducer::word_reduce` into its current form.
///
/// Implicitly contains:
/// n in (R/8, R/4)
/// x in [0, 2n)
/// The value of `n` is fixed for a given `Reducer`,
/// but the value of `x` is modified by the methods.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Reducer<U: HInt> {
// m = 2n
m: U,
// q = (RR/2) / m
// r = (RR/2) % m
// Then RR/2 = qm + r, where `0 <= r < m`
// The value `q` is only needed during construction, so isn't saved.
r: U,
// The value `x` is implicitly stored as `2 * q * x`:
_2xq: U::D,
}

impl<U> Reducer<U>
where
U: HInt,
U: Int<Unsigned = U>,
{
/// Construct a reducer for `(x << _) mod n`.
///
/// Requires `R/8 < n < R/4` and `x < 2n`.
fn new(x: U, n: U) -> Self
where
U::D: NarrowingDiv,
{
let _1 = U::ONE;
assert!(n > (_1 << (U::BITS - 3)));
assert!(n < (_1 << (U::BITS - 2)));
let m = n << 1;
assert!(x < m);

// We need q and r s.t. RR/2 = qm + r, and `0 <= r < m`
// As R/4 < m < R/2,
// we have R <= q < 2R
// so let q = R + f
// RR/2 = (R + f)m + r
// R(R/2 - m) = fm + r

// v = R/2 - m < R/4 < m
let v = (_1 << (U::BITS - 1)) - m;
let (f, r) = v.widen_hi().checked_narrowing_div_rem(m).unwrap();

// xq < qm <= RR/2
// 2xq < RR
// 2xq = 2xR + 2xf;
let _2x: U = x << 1;
let _2xq = _2x.widen_hi() + _2x.widen_mul(f);
Self { m, r, _2xq }
}

/// Extract the current remainder in the range `[0, 2n)`
fn partial_remainder(&self) -> U {
// RR/2 = qm + r, 0 <= r < m
// 2xq = uR + v, 0 <= v < R
// muR = 2mxq - mv
// = xRR - 2xr - mv
// mu + (2xr + mv)/R == xR

// 0 <= 2xq < RR
// R <= q < 2R
// 0 <= x < R/2
// R/4 < m < R/2
// 0 <= r < m
// 0 <= mv < mR
// 0 <= 2xr < rR < mR

// 0 <= (2xr + mv)/R < 2m
// Add `mu` to each term to obtain:
// mu <= xR < mu + 2m

// Since `0 <= 2m < R`, `xR` is the only multiple of `R` between
// `mu` and `m(u+2)`, so the high half of `m(u+2)` must equal `x`.
let _1 = U::ONE;
self.m.widen_mul(self._2xq.hi() + (_1 + _1)).hi()
}

/// Replace the remainder `x` with `(x << k) - un`,
/// for a suitable quotient `u`, which is returned.
fn shift_reduce(&mut self, k: u32) -> U {
assert!(k < U::BITS);
// 2xq << k = aRR/2 + b;
let a = self._2xq.hi() >> (U::BITS - 1 - k);
let (low, high) = (self._2xq << k).lo_hi();
let b = U::D::from_lo_hi(low, high & (U::MAX >> 1));

// (2xq << k) - aqm
// = aRR/2 + b - aqm
// = a(RR/2 - qm) + b
// = ar + b
self._2xq = a.widen_mul(self.r) + b;
a
}

/// Replace the remainder `x` with `x(R/2) - un`,
/// for a suitable quotient `u`, which is returned.
fn word_reduce(&mut self) -> U {
// 2xq = uR + v
let (v, u) = self._2xq.lo_hi();
// xqR - uqm
// = uRR/2 + vR/2 - uRR/2 + ur
// = ur + (v/2)R
self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1);
u
}
}

#[cfg(test)]
mod test {
use crate::support::linear_mul_reduction;
use crate::support::modular::Reducer;

#[test]
fn reducer_ops() {
for n in 33..=63_u8 {
for x in 0..2 * n {
let temp = Reducer::new(x, n);
let n = n as u32;
let x0 = temp.partial_remainder() as u32;
assert_eq!(x as u32, x0);
for k in 0..=7 {
let mut red = temp.clone();
let u = red.shift_reduce(k) as u32;
let x1 = red.partial_remainder() as u32;
assert_eq!(x1, (x0 << k) - u * n);
assert!(x1 < 2 * n);
assert!((red._2xq as u32).is_multiple_of(2 * x1));

// `word_reduce` is equivalent to
// `shift_reduce(U::BITS - 1)`
if k == 7 {
let mut alt = temp.clone();
let w = alt.word_reduce();
assert_eq!(u, w as u32);
assert_eq!(alt, red);
}
}
}
}
}
#[test]
fn reduction_u8() {
for y in 1..64u8 {
for x in 0..2 * y {
let mut r = x % y;
for e in 0..100 {
assert_eq!(r, linear_mul_reduction(x, e, y));
// maintain the correct expected remainder
r <<= 1;
if r >= y {
r -= y;
}
}
}
}
}
#[test]
fn reduction_u128() {
assert_eq!(
linear_mul_reduction::<u128>(17, 100, 123456789),
(17 << 100) % 123456789
);

// power-of-two divisor
assert_eq!(
linear_mul_reduction(0xdead_beef, 100, 1_u128 << 116),
0xbeef << 100
);

let x = 10_u128.pow(37);
let y = 11_u128.pow(36);
assert!(x < y);
let mut r = x;
for e in 0..1000 {
assert_eq!(r, linear_mul_reduction(x, e, y));
// maintain the correct expected remainder
r <<= 1;
if r >= y {
r -= y;
}
assert!(r != 0);
}
}
}
Loading