diff --git a/src/bignum/bignum.cc b/src/bignum/bignum.cc index 3cfe26b..c3588a9 100644 --- a/src/bignum/bignum.cc +++ b/src/bignum/bignum.cc @@ -179,6 +179,23 @@ digits_type do_pow(digits_type lhs, digits_type rhs) { return lhs; } +digits_type do_sqrt(digits_type const& num) { + digits_type one; + one.push_back(1); + digits_type max = num; + digits_type min = do_addition(max, one); + min = do_halve(min); + + while (do_less_than(min, max)) { + max = min; + std::tie(min, std::ignore) = do_div_mod(num, max); + min = do_addition(min, max); + min = do_halve(min); + } + + return max; +} + } // namespace BigNum::BigNum(std::int64_t number) { @@ -431,4 +448,24 @@ BigNum pow(BigNum const& lhs, BigNum const& rhs) { return res; } +BigNum sqrt(BigNum const& num) { + assert(num.is_canonicalized()); + + if (num.is_zero()) { + return BigNum(); + } else if (num.is_negative()) { + throw std::invalid_argument( + "attempt to take the square root of a negative number"); + } + + BigNum res; + + res.digits_ = do_sqrt(num.digits_); + res.sign_ = 1; + + assert(res.is_canonicalized()); + + return res; +} + } // namespace abacus::bignum diff --git a/src/bignum/bignum.hh b/src/bignum/bignum.hh index 59a660e..1880af0 100644 --- a/src/bignum/bignum.hh +++ b/src/bignum/bignum.hh @@ -89,6 +89,8 @@ public: friend BigNum pow(BigNum const& lhs, BigNum const& rhs); + friend BigNum sqrt(BigNum const& num); + friend bool operator==(BigNum const& lhs, BigNum const& rhs) { return lhs.equal(rhs); } diff --git a/tests/unit/bignum.cc b/tests/unit/bignum.cc index 86c6178..8040ed0 100644 --- a/tests/unit/bignum.cc +++ b/tests/unit/bignum.cc @@ -338,3 +338,41 @@ TEST(BigNum, pow_negative) { EXPECT_EQ(pow(minus_three, three), minus_twenty_seven); EXPECT_EQ(pow(three, four), eighty_one); } + +TEST(BigNum, sqrt_zero) { + auto const zero = BigNum(0); + + EXPECT_EQ(sqrt(zero), zero); +} + +TEST(BigNum, sqrt_one) { + auto const one = BigNum(1); + + EXPECT_EQ(sqrt(one), one); +} + +TEST(BigNum, sqrt_truncation) { + auto const one = BigNum(1); + auto const two = BigNum(2); + auto const three = BigNum(3); + + EXPECT_EQ(sqrt(two), one); + EXPECT_EQ(sqrt(three), one); +} + +TEST(BigNum, sqrt) { + auto const two = BigNum(2); + auto const three = BigNum(3); + auto const four = BigNum(4); + auto const nine = BigNum(9); + auto const ten = BigNum(10); + auto const eighty_one = BigNum(81); + auto const ninety_nine = BigNum(99); + auto const hundred = BigNum(100); + + EXPECT_EQ(sqrt(four), two); + EXPECT_EQ(sqrt(nine), three); + EXPECT_EQ(sqrt(eighty_one), nine); + EXPECT_EQ(sqrt(ninety_nine), nine); + EXPECT_EQ(sqrt(hundred), ten); +}