#include "bignum.hh" #include #include #include #include #include #include namespace abacus::bignum { using digits_type = std::vector; namespace { auto static constexpr BASE = 10; bool do_less_than(digits_type const& lhs, digits_type const& rhs) { if (lhs.size() != rhs.size()) { return lhs.size() < rhs.size(); } return std::lexicographical_compare(lhs.rbegin(), lhs.rend(), rhs.rbegin(), rhs.rend()); } void trim_leading_zeros(digits_type& num) { auto const it = std::find_if(num.rbegin(), num.rend(), [](auto v) { return v != 0; }); num.erase(it.base(), num.end()); } std::ostream& do_dump(digits_type const& num, std::ostream& out) { std::copy(num.rbegin(), num.rend(), std::ostream_iterator(out)); return out; } // More optimised than full-on div_mod digits_type do_halve(digits_type num) { assert(num.size() != 0); int carry = 0; for (auto i = num.rbegin(); i != num.rend(); ++i) { auto const was_odd = (*i % 2) == 1; *i /= 2; *i += carry; if (was_odd) { carry = BASE / 2; } else { carry = 0; } } trim_leading_zeros(num); return num; } bool is_odd(digits_type const& num) { if (num.size() == 0) { return false; } return (num.front() % 2) == 1; } digits_type do_addition(digits_type const& lhs, digits_type const& rhs) { int carry = 0; digits_type res; auto it1 = lhs.begin(); auto it2 = rhs.begin(); auto const end1 = lhs.end(); auto const end2 = rhs.end(); while (it1 != end1 && it2 != end2) { int addition = *it1 + *it2 + carry; carry = addition / BASE; res.push_back(addition % BASE); ++it1; ++it2; } auto leftover = [=]() { if (it1 != end1) { return std::span(it1, end1); } return std::span(it2, end2); }(); for (auto value : leftover) { int addition = value + carry; carry = addition / BASE; res.push_back(addition % BASE); } if (carry != 0) { res.push_back(carry); } return res; } digits_type do_substraction(digits_type const& lhs, digits_type const& rhs) { assert(!do_less_than(lhs, rhs)); digits_type complement; auto const take_complement = [](auto num) { return 9 - num; }; std::transform(lhs.begin(), lhs.end(), std::back_inserter(complement), take_complement); complement = do_addition(complement, rhs); std::transform(complement.begin(), complement.end(), complement.begin(), take_complement); trim_leading_zeros(complement); return complement; } digits_type do_multiplication(digits_type const& lhs, digits_type const& rhs) { digits_type res(lhs.size() + rhs.size()); for (std::size_t i = 0; i < lhs.size(); ++i) { int carry = 0; for (std::size_t j = 0; j < rhs.size(); ++j) { int multiplication = lhs[i] * rhs[j]; res[i + j] += multiplication + carry; carry = res[i + j] / BASE; res[i + j] %= BASE; } res[i + rhs.size()] += carry; } return res; } std::pair do_div_mod(digits_type const& lhs, digits_type const& rhs) { if (rhs.size() == 0) { throw std::invalid_argument("attempt to divide by zero"); } digits_type multiple = rhs; digits_type rank; rank.push_back(1); while (!do_less_than(lhs, multiple)) { multiple = do_addition(multiple, multiple); rank = do_addition(rank, rank); } digits_type quotient; digits_type remainder = lhs; while (!do_less_than(remainder, rhs)) { while (do_less_than(remainder, multiple)) { multiple = do_halve(multiple); rank = do_halve(rank); } assert(!do_less_than(multiple, rhs)); quotient = do_addition(quotient, rank); remainder = do_substraction(remainder, multiple); } return std::make_pair(quotient, remainder); } digits_type do_pow(digits_type lhs, digits_type rhs) { assert(rhs.size() != 0); auto original = lhs; while (rhs.size() != 0 && !(rhs.size() == 1 && rhs.front() == 1)) { lhs = do_multiplication(lhs, lhs); if (is_odd(rhs)) { lhs = do_multiplication(lhs, original); } rhs = do_halve(rhs); } return lhs; } digits_type do_sqrt(digits_type const& num) { digits_type one; one.push_back(1); digits_type max = num; digits_type min = do_addition(max, one); min = do_halve(min); while (do_less_than(min, max)) { max = min; std::tie(min, std::ignore) = do_div_mod(num, max); min = do_addition(min, max); min = do_halve(min); } return max; } } // namespace BigNum::BigNum(std::int64_t number) { if (number == 0) { return; } if (number < 0) { sign_ = -1; } else { sign_ = 1; } auto abs = static_cast(std::abs(number)); do { digits_.push_back(abs % BASE); abs /= BASE; } while (abs); assert(is_canonicalized()); } std::ostream& BigNum::dump(std::ostream& out) const { if (is_zero()) { return out << '0'; } if (is_negative()) { out << '-'; } return do_dump(digits_, out); } std::istream& BigNum::read(std::istream& in) { bool parsed = false; bool leading = true; if (in.peek() == '-') { in.get(); sign_ = -1; } else { sign_ = 1; } digits_type digits; while (std::isdigit(in.peek())) { parsed = true; int digit = in.get() - '0'; if (digit != 0 || !leading) { digits.push_back(digit); leading = false; } } if (leading) { sign_ = 0; } if (!parsed) { in.setstate(std::ios::failbit); } else { std::reverse(digits.begin(), digits.end()); digits_ = std::move(digits); } return in; } void BigNum::flip_sign() { assert(is_canonicalized()); sign_ *= -1; } void BigNum::add(BigNum const& rhs) { assert(is_canonicalized()); assert(rhs.is_canonicalized()); if (rhs.is_zero()) { return; } if (is_zero()) { *this = rhs; return; } if (sign_ == rhs.sign_) { digits_ = do_addition(digits_, rhs.digits_); } else { bool flipped = do_less_than(digits_, rhs.digits_); if (flipped) { digits_ = do_substraction(rhs.digits_, digits_); } else { digits_ = do_substraction(digits_, rhs.digits_); } if (flipped) { flip_sign(); } canonicalize(); } assert(is_canonicalized()); } void BigNum::substract(BigNum const& rhs) { assert(is_canonicalized()); assert(rhs.is_canonicalized()); flip_sign(); add(rhs); flip_sign(); assert(is_canonicalized()); } void BigNum::multiply(BigNum const& rhs) { assert(is_canonicalized()); assert(rhs.is_canonicalized()); if (is_zero() || rhs.is_zero()) { *this = BigNum(); return; } digits_ = do_multiplication(digits_, rhs.digits_); sign_ *= rhs.sign_; canonicalize(); } void BigNum::divide(BigNum const& rhs) { std::tie(*this, std::ignore) = div_mod(*this, rhs); } void BigNum::modulo(BigNum const& rhs) { std::tie(std::ignore, *this) = div_mod(*this, rhs); } bool BigNum::equal(BigNum const& rhs) const { assert(is_canonicalized()); assert(rhs.is_canonicalized()); if (sign_ != rhs.sign_) { return false; } return digits_ == rhs.digits_; } bool BigNum::less_than(BigNum const& rhs) const { assert(is_canonicalized()); assert(rhs.is_canonicalized()); if (sign_ != rhs.sign_) { return sign_ < rhs.sign_; } if (is_positive()) { return do_less_than(digits_, rhs.digits_); } else { return do_less_than(rhs.digits_, digits_); } } void BigNum::canonicalize() { trim_leading_zeros(digits_); if (digits_.size() == 0) { sign_ = 0; } assert(is_canonicalized()); } bool BigNum::is_canonicalized() const { if (digits_.size() == 0) { return sign_ == 0; } // `back` is valid since there is at least one element auto const has_leading_zero = digits_.back() == 0; if (has_leading_zero) { return false; } auto const has_overflow = std::any_of(digits_.begin(), digits_.end(), [](auto v) { return v >= BASE; }); if (has_overflow) { return false; } return true; } bool BigNum::is_zero() const { assert(is_canonicalized()); return sign_ == 0; } bool BigNum::is_positive() const { assert(is_canonicalized()); return sign_ >= 0; } bool BigNum::is_negative() const { assert(is_canonicalized()); return sign_ <= 0; } std::pair div_mod(BigNum const& lhs, BigNum const& rhs) { assert(lhs.is_canonicalized()); assert(rhs.is_canonicalized()); if (lhs.is_zero()) { return std::make_pair(BigNum(), BigNum()); } auto quotient = BigNum(0); auto remainder = BigNum(0); std::tie(quotient.digits_, remainder.digits_) = do_div_mod(lhs.digits_, rhs.digits_); // Respect the identity `(a/b)*b + (a%b) = a` quotient.sign_ = lhs.sign_ * rhs.sign_; remainder.sign_ = lhs.sign_; quotient.canonicalize(); remainder.canonicalize(); return std::make_pair(quotient, remainder); } BigNum pow(BigNum const& lhs, BigNum const& rhs) { assert(lhs.is_canonicalized()); assert(rhs.is_canonicalized()); if (rhs.is_zero()) { return BigNum(1); } else if (rhs.is_negative()) { return BigNum(); } auto res = BigNum(0); res.digits_ = do_pow(lhs.digits_, rhs.digits_); res.sign_ = is_odd(rhs.digits_) ? lhs.sign_ : 1; res.canonicalize(); return res; } BigNum sqrt(BigNum const& num) { assert(num.is_canonicalized()); if (num.is_zero()) { return BigNum(); } else if (num.is_negative()) { throw std::invalid_argument( "attempt to take the square root of a negative number"); } auto res = BigNum(0); res.digits_ = do_sqrt(num.digits_); res.sign_ = 1; assert(res.is_canonicalized()); return res; } BigNum log2(BigNum const& num) { assert(num.is_canonicalized()); if (num.is_zero()) { throw std::invalid_argument("attempt to take the log2 of zero"); } else if (num.is_negative()) { throw std::invalid_argument( "attempt to take the log2 of a negative number"); } auto tmp = num; auto res = BigNum(0); auto one = BigNum(1); while (tmp > one) { tmp.digits_ = do_halve(tmp.digits_); res += one; } assert(res.is_canonicalized()); return res; } BigNum log10(BigNum const& num) { assert(num.is_canonicalized()); assert(BASE == 10); if (num.is_zero()) { throw std::invalid_argument("attempt to take the log10 of zero"); } else if (num.is_negative()) { throw std::invalid_argument( "attempt to take the log10 of a negative number"); } auto res = BigNum(num.digits_.size() - 1); assert(res.is_canonicalized()); return res; } } // namespace abacus::bignum