| /* SPDX-License-Identifier: MIT OR Apache-2.0 */ |
| |
| //! This module provides accelerated modular multiplication by large powers |
| //! of two, which is needed for computing floating point remainders in `fmod` |
| //! and similar functions. |
| //! |
| //! To keep the equations somewhat concise, the following conventions are used: |
| //! - all integer operations are in the mathematical sense, without overflow |
| //! - concatenation means multiplication: `2xq = 2 * x * q` |
| //! - `R = (1 << U::BITS)` is the modulus of wrapping arithmetic in `U` |
| |
| use crate::support::int_traits::NarrowingDiv; |
| use crate::support::{DInt, HInt, Int}; |
| |
| /// Compute the remainder `(x << e) % y` with unbounded integers. |
| /// Requires `x < 2y` and `y.leading_zeros() >= 2` |
| pub fn linear_mul_reduction<U>(x: U, mut e: u32, mut y: U) -> U |
| where |
| U: HInt + Int<Unsigned = U>, |
| U::D: NarrowingDiv, |
| { |
| assert!(y <= U::MAX >> 2); |
| assert!(x < (y << 1)); |
| let _0 = U::ZERO; |
| let _1 = U::ONE; |
| |
| // power of two divisors |
| if (y & (y - _1)).is_zero() { |
| if e < U::BITS { |
| // shift and only keep low bits |
| return (x << e) & (y - _1); |
| } else { |
| // would shift out all the bits |
| return _0; |
| } |
| } |
| |
| // Use the identity `(x << e) % y == ((x << (e + s)) % (y << s)) >> s` |
| // to shift the divisor so it has exactly two leading zeros to satisfy |
| // the precondition of `Reducer::new` |
| let s = y.leading_zeros() - 2; |
| e += s; |
| y <<= s; |
| |
| // `m: Reducer` keeps track of the remainder `x` in a form that makes it |
| // very efficient to do `x <<= k` modulo `y` for integers `k < U::BITS` |
| let mut m = Reducer::new(x, y); |
| |
| // Use the faster special case with constant `k == U::BITS - 1` while we can |
| while e >= U::BITS - 1 { |
| m.word_reduce(); |
| e -= U::BITS - 1; |
| } |
| // Finish with the variable shift operation |
| m.shift_reduce(e); |
| |
| // The partial remainder is in `[0, 2y)` ... |
| let r = m.partial_remainder(); |
| // ... so check and correct, and compensate for the earlier shift. |
| r.checked_sub(y).unwrap_or(r) >> s |
| } |
| |
| /// Helper type for computing the reductions. The implementation has a number |
| /// of seemingly weird choices, but everything is aimed at streamlining |
| /// `Reducer::word_reduce` into its current form. |
| /// |
| /// Implicitly contains: |
| /// n in (R/8, R/4) |
| /// x in [0, 2n) |
| /// The value of `n` is fixed for a given `Reducer`, |
| /// but the value of `x` is modified by the methods. |
| #[derive(Debug, Clone, PartialEq, Eq)] |
| struct Reducer<U: HInt> { |
| // m = 2n |
| m: U, |
| // q = (RR/2) / m |
| // r = (RR/2) % m |
| // Then RR/2 = qm + r, where `0 <= r < m` |
| // The value `q` is only needed during construction, so isn't saved. |
| r: U, |
| // The value `x` is implicitly stored as `2 * q * x`: |
| _2xq: U::D, |
| } |
| |
| impl<U> Reducer<U> |
| where |
| U: HInt, |
| U: Int<Unsigned = U>, |
| { |
| /// Construct a reducer for `(x << _) mod n`. |
| /// |
| /// Requires `R/8 < n < R/4` and `x < 2n`. |
| fn new(x: U, n: U) -> Self |
| where |
| U::D: NarrowingDiv, |
| { |
| let _1 = U::ONE; |
| assert!(n > (_1 << (U::BITS - 3))); |
| assert!(n < (_1 << (U::BITS - 2))); |
| let m = n << 1; |
| assert!(x < m); |
| |
| // We need to compute the parameters |
| // `q = (RR/2) / m` |
| // `r = (RR/2) % m` |
| |
| // Since `m` is in `(R/4, R/2)`, the quotient `q` is in `[R, 2R)`, and |
| // it would overflow in `U` if computed directly. Instead, we compute |
| // `f = q - R`, which is in `[0, R)`. To do so, we simply subtract `Rm` |
| // from the dividend, which doesn't change the remainder: |
| // `f = R(R/2 - m) / m` |
| // `r = R(R/2 - m) % m` |
| let dividend = ((_1 << (U::BITS - 1)) - m).widen_hi(); |
| let (f, r) = dividend.checked_narrowing_div_rem(m).unwrap(); |
| |
| // As `x < m`, `xq < qm <= RR/2` |
| // Thus `2xq = 2xR + 2xf` does not overflow in `U::D`. |
| let _2x = x + x; |
| let _2xq = _2x.widen_hi() + _2x.widen_mul(f); |
| Self { m, r, _2xq } |
| } |
| |
| /// Extract the current remainder `x` in the range `[0, 2n)` |
| fn partial_remainder(&self) -> U { |
| // `RR/2 = qm + r`, where `0 <= r < m` |
| // `2xq = uR + v`, where `0 <= v < R` |
| |
| // The goal is to extract the current value of `x` from the value `2xq` |
| // that we actually have. A bit simplified, we could multiply it by `m` |
| // to obtain `2xqm == 2x(RR/2 - r) == xRR - 2xr`, where `2xr < RR`. |
| // We could just round that up to the next multiple of `RR` to get `x`, |
| // but we can avoid having to multiply the full double-wide `2xq` by |
| // making a couple of adjustments: |
| |
| // First, let's only use the high half `u` for the product, and |
| // include an additional error term due to the truncation: |
| // `mu = xR - (2xr + mv)/R` |
| |
| // Next, show bounds for the error term |
| // `0 <= mv < mR` follows from `0 <= v < R` |
| // `0 <= 2xr < mR` follows from `0 <= x < m < R/2` and `0 <= r < m` |
| // Adding those together, we have: |
| // `0 <= (mv + 2xr)/R < 2m` |
| // Which also implies: |
| // `0 < 2m - (mv + 2xr)/R <= 2m < R` |
| |
| // For that reason, we can use `u + 2` as the factor to obtain |
| // `m(u + 2) = xR + (2m - (mv + 2xr)/R)` |
| // By the previous inequality, the second term fits neatly in the lower |
| // half, so we get exactly `x` as the high half. |
| let u = self._2xq.hi(); |
| let _2 = U::ONE + U::ONE; |
| self.m.widen_mul(u + _2).hi() |
| |
| // Additionally, we should ensure that `u + 2` cannot overflow: |
| // Since `x < m` and `2qm <= RR`, |
| // `2xq <= 2q(m-1) <= RR - 2q` |
| // As we also have `q > R`, |
| // `2xq < RR - 2R` |
| // which is sufficient. |
| } |
| |
| /// Replace the remainder `x` with `(x << k) - un`, |
| /// for a suitable quotient `u`, which is returned. |
| /// |
| /// Requires that `k < U::BITS`. |
| fn shift_reduce(&mut self, k: u32) -> U { |
| assert!(k < U::BITS); |
| |
| // First, split the shifted value: |
| // `2xq << k = aRR/2 + b`, where `0 <= b < RR/2` |
| let a = self._2xq.hi() >> (U::BITS - 1 - k); |
| let (low, high) = (self._2xq << k).lo_hi(); |
| let b = U::D::from_lo_hi(low, high & (U::MAX >> 1)); |
| |
| // Then, subtract `2anq = aqm`: |
| // ``` |
| // (2xq << k) - aqm |
| // = aRR/2 + b - aqm |
| // = a(RR/2 - qm) + b |
| // = ar + b |
| // ``` |
| self._2xq = a.widen_mul(self.r) + b; |
| a |
| |
| // Since `a` is at most the high half of `2xq`, we have |
| // `a + 2 < R` (shown above, in `partial_remainder`) |
| // Using that together with `b < RR/2` and `r < m < R/2`, |
| // we get `(a + 2)r + b < RR`, so |
| // `ar + b < RR - 2r = 2mq` |
| // which shows that the new remainder still satisfies `x < m`. |
| } |
| |
| // NB: `word_reduce()` is just the special case `shift_reduce(U::BITS - 1)` |
| // that optimizes especially well. The correspondence is that `a == u` and |
| // `b == (v >> 1).widen_hi()` |
| // |
| /// Replace the remainder `x` with `x(R/2) - un`, |
| /// for a suitable quotient `u`, which is returned. |
| fn word_reduce(&mut self) -> U { |
| // To do so, we replace `2xq = uR + v` with |
| // ``` |
| // 2 * (x(R/2) - un) * q |
| // = xqR - 2unq |
| // = xqR - uqm |
| // = uRR/2 + vR/2 - uRR/2 + ur |
| // = ur + (v/2)R |
| // ``` |
| let (v, u) = self._2xq.lo_hi(); |
| self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1); |
| u |
| |
| // Additional notes: |
| // 1. As `v` is the low bits of `2xq`, it is even and can be halved. |
| // 2. The new remainder is `(xr + mv/2) / R` (see below) |
| // and since `v < R`, `r < m`, `x < m < R/2`, |
| // that is also strictly less than `m`. |
| // ``` |
| // (x(R/2) - un)R |
| // = xRR/2 - (m/2)uR |
| // = x(qm + r) - (m/2)(2xq - v) |
| // = xqm + xr - xqm + mv/2 |
| // = xr + mv/2 |
| // ``` |
| } |
| } |
| |
| #[cfg(test)] |
| mod test { |
| use crate::support::linear_mul_reduction; |
| use crate::support::modular::Reducer; |
| |
| #[test] |
| fn reducer_ops() { |
| for n in 33..=63_u8 { |
| for x in 0..2 * n { |
| let temp = Reducer::new(x, n); |
| let n = n as u32; |
| let x0 = temp.partial_remainder() as u32; |
| assert_eq!(x as u32, x0); |
| for k in 0..=7 { |
| let mut red = temp.clone(); |
| let u = red.shift_reduce(k) as u32; |
| let x1 = red.partial_remainder() as u32; |
| assert_eq!(x1, (x0 << k) - u * n); |
| assert!(x1 < 2 * n); |
| assert!((red._2xq as u32).is_multiple_of(2 * x1)); |
| |
| // `word_reduce` is equivalent to |
| // `shift_reduce(U::BITS - 1)` |
| if k == 7 { |
| let mut alt = temp.clone(); |
| let w = alt.word_reduce(); |
| assert_eq!(u, w as u32); |
| assert_eq!(alt, red); |
| } |
| } |
| } |
| } |
| } |
| #[test] |
| fn reduction_u8() { |
| for y in 1..64u8 { |
| for x in 0..2 * y { |
| let mut r = x % y; |
| for e in 0..100 { |
| assert_eq!(r, linear_mul_reduction(x, e, y)); |
| // maintain the correct expected remainder |
| r <<= 1; |
| if r >= y { |
| r -= y; |
| } |
| } |
| } |
| } |
| } |
| #[test] |
| fn reduction_u128() { |
| assert_eq!( |
| linear_mul_reduction::<u128>(17, 100, 123456789), |
| (17 << 100) % 123456789 |
| ); |
| |
| // power-of-two divisor |
| assert_eq!( |
| linear_mul_reduction(0xdead_beef, 100, 1_u128 << 116), |
| 0xbeef << 100 |
| ); |
| |
| let x = 10_u128.pow(37); |
| let y = 11_u128.pow(36); |
| assert!(x < y); |
| let mut r = x; |
| for e in 0..1000 { |
| assert_eq!(r, linear_mul_reduction(x, e, y)); |
| // maintain the correct expected remainder |
| r <<= 1; |
| if r >= y { |
| r -= y; |
| } |
| assert!(r != 0); |
| } |
| } |
| } |