diff --git a/src/algorithms.rs b/src/algorithms.rs index c65f3b4f..10f7f22f 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -266,6 +266,35 @@ pub(crate) fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) { } } +/// Subtract a multiple. +/// a -= b * c +/// Returns a borrow (if a < b then borrow > 0). +fn sub_mul_digit_same_len(a: &mut [BigDigit], b: &[BigDigit], c: BigDigit) -> BigDigit { + debug_assert!(a.len() == b.len()); + + // carry is between -big_digit::MAX and 0, so to avoid overflow we store + // offset_carry = carry + big_digit::MAX + let mut offset_carry = big_digit::MAX; + + for (x, y) in a.iter_mut().zip(b) { + // We want to calculate sum = x - y * c + carry. + // sum >= -(big_digit::MAX * big_digit::MAX) - big_digit::MAX + // sum <= big_digit::MAX + // Offsetting sum by (big_digit::MAX << big_digit::BITS) puts it in DoubleBigDigit range. + let offset_sum = big_digit::to_doublebigdigit(big_digit::MAX, *x) + - big_digit::MAX as DoubleBigDigit + + offset_carry as DoubleBigDigit + - *y as DoubleBigDigit * c as DoubleBigDigit; + + let (new_offset_carry, new_x) = big_digit::from_doublebigdigit(offset_sum); + offset_carry = new_offset_carry; + *x = new_x; + } + + // Return the borrow. + big_digit::MAX - offset_carry +} + fn bigint_from_slice(slice: &[BigDigit]) -> BigInt { BigInt::from(biguint_from_vec(slice.to_vec())) } @@ -631,78 +660,99 @@ pub(crate) fn div_rem_ref(u: &BigUint, d: &BigUint) -> (BigUint, BigUint) { (q, r >> shift) } -/// an implementation of Knuth, TAOCP vol 2 section 4.3, algorithm D -/// -/// # Correctness -/// -/// This function requires the following conditions to run correctly and/or effectively -/// -/// - `a > b` -/// - `d.data.len() > 1` -/// - `d.data.last().unwrap().leading_zeros() == 0` +/// An implementation of the base division algorithm. +/// Knuth, TAOCP vol 2 section 4.3.1, algorithm D, with an improvement from exercises 19-21. fn div_rem_core(mut a: BigUint, b: &BigUint) -> (BigUint, BigUint) { - // The algorithm works by incrementally calculating "guesses", q0, for part of the - // remainder. Once we have any number q0 such that q0 * b <= a, we can set + debug_assert!( + a.data.len() >= b.data.len() + && b.data.len() > 1 + && b.data.last().unwrap().leading_zeros() == 0 + ); + + // The algorithm works by incrementally calculating "guesses", q0, for the next digit of the + // quotient. Once we have any number q0 such that (q0 << j) * b <= a, we can set // - // q += q0 - // a -= q0 * b + // q += q0 << j + // a -= (q0 << j) * b // // and then iterate until a < b. Then, (q, a) will be our desired quotient and remainder. // - // q0, our guess, is calculated by dividing the last few digits of a by the last digit of b - // - this should give us a guess that is "close" to the actual quotient, but is possibly - // greater than the actual quotient. If q0 * b > a, we simply use iterated subtraction - // until we have a guess such that q0 * b <= a. + // q0, our guess, is calculated by dividing the last three digits of a by the last two digits of + // b - this will give us a guess that is close to the actual quotient, but is possibly greater. + // It can only be greater by 1 and only in rare cases, with probability at most + // 2^-(big_digit::BITS-1) for random a, see TAOCP 4.3.1 exercise 21. // + // If the quotient turns out to be too large, we adjust it by 1: + // q -= 1 << j + // a += b << j + + // a0 stores an additional extra most significant digit of the dividend, not stored in a. + let mut a0 = 0; + + // [b1, b0] are the two most significant digits of the divisor. They never change. + let b0 = *b.data.last().unwrap(); + let b1 = b.data[b.data.len() - 2]; - let bn = *b.data.last().unwrap(); let q_len = a.data.len() - b.data.len() + 1; let mut q = BigUint { data: vec![0; q_len], }; - // We reuse the same temporary to avoid hitting the allocator in our inner loop - this is - // sized to hold a0 (in the common case; if a particular digit of the quotient is zero a0 - // can be bigger). - // - let mut tmp = BigUint { - data: Vec::with_capacity(2), - }; - for j in (0..q_len).rev() { - // When calculating our next guess q0, we don't need to consider the digits below j - // + b.data.len() - 1: we're guessing digit j of the quotient (i.e. q0 << j) from - // digit bn of the divisor (i.e. bn << (b.data.len() - 1) - so the product of those - // two numbers will be zero in all digits up to (j + b.data.len() - 1). - let offset = j + b.data.len() - 1; - if offset >= a.data.len() { - continue; + debug_assert!(a.data.len() == b.data.len() + j); + + let a1 = *a.data.last().unwrap(); + let a2 = a.data[a.data.len() - 2]; + + // The first q0 estimate is [a1,a0] / b0. It will never be too small, it may be too large + // by at most 2. + let (mut q0, mut r) = if a0 < b0 { + let (q0, r) = div_wide(a0, a1, b0); + (q0, r as DoubleBigDigit) + } else { + debug_assert!(a0 == b0); + // Avoid overflowing q0, we know the quotient fits in BigDigit. + // [a1,a0] = b0 * (1< a0 { + // q0 is too large. We need to add back one multiple of b. + q0 -= 1; + borrow -= __add2(&mut a.data[j..], &b.data); } + // The top digit of a, stored in a0, has now been zeroed. + debug_assert!(borrow == a0); - add2(&mut q.data[j..], &q0.data[..]); - sub2(&mut a.data[j..], &prod.data[..]); - a.normalize(); + q.data[j] = q0; - tmp = q0; + // Pop off the next top digit of a. + a0 = a.data.pop().unwrap(); } + a.data.push(a0); + a.normalize(); + debug_assert!(a < *b); (q.normalized(), a) diff --git a/src/lib.rs b/src/lib.rs index 30c5abec..ab2bd4e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -266,6 +266,7 @@ mod big_digit { pub(crate) const HALF: BigDigit = (1 << HALF_BITS) - 1; const LO_MASK: DoubleBigDigit = (1 << BITS) - 1; + pub(crate) const MAX: BigDigit = LO_MASK as BigDigit; #[inline] fn get_hi(n: DoubleBigDigit) -> BigDigit { diff --git a/tests/biguint.rs b/tests/biguint.rs index e14dbdd4..14c33daa 100644 --- a/tests/biguint.rs +++ b/tests/biguint.rs @@ -897,6 +897,20 @@ fn test_div_rem() { } } +#[test] +fn test_div_rem_big_multiple() { + let a = BigUint::from(3u32).pow(100u32); + let a2 = &a * &a; + + let (div, rem) = a2.div_rem(&a); + assert_eq!(div, a); + assert!(rem.is_zero()); + + let (div, rem) = (&a2 - 1u32).div_rem(&a); + assert_eq!(div, &a - 1u32); + assert_eq!(rem, &a - 1u32); +} + #[test] fn test_div_ceil() { fn check(a: &BigUint, b: &BigUint, d: &BigUint, m: &BigUint) {