Add methods to mpz, mpq, mpbq

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-07-17 14:24:35 -07:00
parent 88b49ec21f
commit d028041135
7 changed files with 307 additions and 26 deletions

View file

@ -58,11 +58,18 @@ void tst4() {
}
}
void tst5() {
lean::mpbq n(7,4);
std::cout << lean::mpbq::decimal(n) << "\n";
lean::mpq r(4, 3);
std::cout << lean::mpq::decimal(r) << "\n";
}
int main() {
std::cout << "Lean (version " << LEAN_VERSION_MAJOR << "." << LEAN_VERSION_MINOR << ")\n";
tst1();
tst2();
tst3();
tst5();
std::cout << "done\n";
return 0;
}

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <limits.h>
#include "mpbq.h"
namespace lean {
@ -39,14 +40,30 @@ int cmp(mpbq const & a, mpbq const & b) {
int cmp(mpbq const & a, mpz const & b) {
static thread_local mpz tmp;
if (a.m_k == 0)
if (a.m_k == 0) {
return cmp(a.m_num, b);
}
else {
mul2k(tmp, b, a.m_k);
return cmp(a.m_num, tmp);
}
}
int cmp(mpbq const & a, mpq const & b) {
if (a.is_integer() && b.is_integer()) {
return a.m_num < b;
}
else {
static thread_local mpz tmp1;
static thread_local mpz tmp2;
// tmp1 <- numerator(a)*denominator(b)
denominator(tmp1, b); tmp1 *= a.m_num;
// tmp2 <- numerator(b)*denominator(a)
numerator(tmp2, b); mul2k(tmp2, tmp2, a.m_k);
return tmp1 < tmp2;
}
}
mpbq & mpbq::operator+=(mpbq const & a) {
if (m_k == a.m_k) {
m_num += a.m_num;
@ -142,6 +159,15 @@ mpbq & mpbq::mul_int(T const & a) {
mpbq & mpbq::operator*=(unsigned a) { return mul_int<unsigned>(a); }
mpbq & mpbq::operator*=(int a) { return mul_int<int>(a); }
void power(mpbq & a, mpbq const & b, unsigned k) {
lean_assert(static_cast<unsigned long long>(k) * static_cast<unsigned long long>(b.m_k) <= static_cast<unsigned long long>(UINT_MAX));
// We don't need to normalize because:
// If b.m_k == 0, then b is an integer, and the result be an integer
// If b.m_k > 0, then b.m_num must be odd, and the (b.m_num)^k will also be odd
a.m_k = b.m_k * k;
power(a.m_num, b.m_num, k);
}
int mpbq::magnitude_lb() const {
int s = m_num.sgn();
if (s < 0) {
@ -192,6 +218,96 @@ void mul2k(mpbq & a, unsigned k) {
}
}
bool root_lower(mpbq & a, mpbq const & b, unsigned n) {
bool r = root(a.m_num, b.m_num, n);
if (!r)
--a.m_num;
if (b.m_k % n == 0) {
a.m_k = b.m_k / n;
a.normalize();
return r;
}
else if (a.m_num.is_neg()) {
a.m_k = b.m_k / n;
a.normalize();
return false;
}
else {
a.m_k = b.m_k / n;
a.m_k++;
a.normalize();
return false;
}
}
bool root_upper(mpbq & a, mpbq const & b, unsigned n) {
bool r = root(a.m_num, b.m_num, n);
if (b.m_k % n == 0) {
a.m_k = b.m_k / n;
a.normalize();
return r;
}
else if (a.m_num.is_neg()) {
a.m_k = b.m_k / n;
a.m_k++;
a.normalize();
return false;
}
else {
a.m_k = b.m_k / n;
a.normalize();
return false;
}
}
void refine_upper(mpq const & q, mpbq & l, mpbq & u) {
lean_assert(l < q && q < u);
lean_assert(!q.get_denominator().is_power_of_two());
mpbq mid;
while (true) {
mid = l + u;
div2(mid);
if (mid > q) {
u.swap(mid);
lean_assert(l < q && q < u);
return;
}
l.swap(mid);
}
}
void refine_lower(mpq const & q, mpbq & l, mpbq & u) {
lean_assert(l < q && q < u);
lean_assert(!q.get_denominator().is_power_of_two());
mpbq mid;
while (true) {
mid = l + u;
div2(mid);
if (mid < q) {
l.swap(mid);
lean_assert(l < q && q < u);
return;
}
u.swap(mid);
}
}
bool lt_1div2k(mpbq const & a, unsigned k) {
if (a.m_num.is_nonpos())
return true;
if (a.m_k <= k) {
// since a.m_num >= 1
return false;
}
else {
lean_assert(a.m_k > k);
static thread_local mpz tmp;
tmp = 1;
mul2k(tmp, tmp, a.m_k - k);
return a.m_num < tmp;
}
}
std::ostream & operator<<(std::ostream & out, mpbq const & v) {
if (v.m_k == 0) {
out << v.m_num;
@ -205,6 +321,35 @@ std::ostream & operator<<(std::ostream & out, mpbq const & v) {
return out;
}
void display_decimal(std::ostream & out, mpbq const & a, unsigned prec) {
if (a.is_integer()) {
out << a.m_num;
return;
}
else {
mpz two_k;
mpz n1, v1;
if (a.is_neg())
out << "-";
v1 = abs(a.m_num);
power(two_k, mpz(2), a.m_k);
n1 = rem(v1, two_k);
v1 = v1/two_k;
lean_assert(!n1.is_zero());
out << v1;
out << ".";
for (unsigned i = 0; i < prec; i++) {
n1 *= 10;
v1 = n1/two_k;
n1 = rem(n1, two_k);
out << v1;
if (n1.is_zero())
return;
}
out << "?";
}
}
}
void pp(lean::mpbq const & n) { std::cout << n << std::endl; }

View file

@ -6,6 +6,7 @@ Author: Leonardo de Moura
*/
#pragma once
#include "mpz.h"
#include "mpq.h"
namespace lean {
@ -55,54 +56,67 @@ public:
friend int cmp(mpbq const & a, mpbq const & b);
friend int cmp(mpbq const & a, mpz const & b);
friend int cmp(mpbq const & a, mpq const & b);
friend int cmp(mpbq const & a, unsigned b) { return cmp(a, mpbq(b)); }
friend int cmp(mpbq const & a, int b) { return cmp(a, mpbq(b)); }
friend bool operator<(mpbq const & a, mpbq const & b) { return cmp(a, b) < 0; }
friend bool operator<(mpbq const & a, mpz const & b) { return cmp(a, b) < 0; }
friend bool operator<(mpbq const & a, mpq const & b) { return cmp(a, b) < 0; }
friend bool operator<(mpbq const & a, unsigned b) { return cmp(a, b) < 0; }
friend bool operator<(mpbq const & a, int b) { return cmp(a, b) < 0; }
friend bool operator<(mpz const & a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator<(mpq const & a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator<(unsigned a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator<(int a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator<(mpz const & a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator>(mpbq const & a, mpbq const & b) { return cmp(a, b) > 0; }
friend bool operator>(mpbq const & a, mpz const & b) { return cmp(a, b) > 0; }
friend bool operator>(mpbq const & a, mpq const & b) { return cmp(a, b) > 0; }
friend bool operator>(mpbq const & a, unsigned b) { return cmp(a, b) > 0; }
friend bool operator>(mpbq const & a, int b) { return cmp(a, b) > 0; }
friend bool operator>(unsigned a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator>(int a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator>(mpz const & a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator>(mpz const & a, mpbq const & b) { return cmp(b, a) < 0; }
friend bool operator>(mpq const & a, mpbq const & b) { return cmp(b, a) < 0; }
friend bool operator>(unsigned a, mpbq const & b) { return cmp(b, a) < 0; }
friend bool operator>(int a, mpbq const & b) { return cmp(b, a) < 0; }
friend bool operator<=(mpbq const & a, mpbq const & b) { return cmp(a, b) <= 0; }
friend bool operator<=(mpbq const & a, mpz const & b) { return cmp(a, b) <= 0; }
friend bool operator<=(mpbq const & a, mpq const & b) { return cmp(a, b) <= 0; }
friend bool operator<=(mpbq const & a, unsigned b) { return cmp(a, b) <= 0; }
friend bool operator<=(mpbq const & a, int b) { return cmp(a, b) <= 0; }
friend bool operator<=(unsigned a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator<=(int a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator<=(mpz const & a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator<=(mpz const & a, mpbq const & b) { return cmp(b, a) >= 0; }
friend bool operator<=(mpq const & a, mpbq const & b) { return cmp(b, a) >= 0; }
friend bool operator<=(unsigned a, mpbq const & b) { return cmp(b, a) >= 0; }
friend bool operator<=(int a, mpbq const & b) { return cmp(b, a) >= 0; }
friend bool operator>=(mpbq const & a, mpbq const & b) { return cmp(a, b) >= 0; }
friend bool operator>=(mpbq const & a, mpz const & b) { return cmp(a, b) >= 0; }
friend bool operator>=(mpbq const & a, mpq const & b) { return cmp(a, b) >= 0; }
friend bool operator>=(mpbq const & a, unsigned b) { return cmp(a, b) >= 0; }
friend bool operator>=(mpbq const & a, int b) { return cmp(a, b) >= 0; }
friend bool operator>=(unsigned a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator>=(int a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator>=(mpz const & a, mpbq const & b) { return cmp(b, a) > 0; }
friend bool operator>=(mpz const & a, mpbq const & b) { return cmp(b, a) <= 0; }
friend bool operator>=(mpq const & a, mpbq const & b) { return cmp(b, a) <= 0; }
friend bool operator>=(unsigned a, mpbq const & b) { return cmp(b, a) <= 0; }
friend bool operator>=(int a, mpbq const & b) { return cmp(b, a) <= 0; }
friend bool operator==(mpbq const & a, mpbq const & b) { return a.m_k == b.m_k && a.m_num == b.m_num; }
friend bool operator==(mpbq const & a, mpz const & b) { return a.is_integer() && a.m_num == b; }
friend bool operator==(mpbq const & a, mpq const & b) { return cmp(a, b) == 0; }
friend bool operator==(mpbq const & a, unsigned int b) { return a.is_integer() && a.m_num == b; }
friend bool operator==(mpbq const & a, int b) { return a.is_integer() && a.m_num == b; }
friend bool operator==(mpz const & a, mpbq const & b) { return operator==(b, a); }
friend bool operator==(mpq const & a, mpbq const & b) { return operator==(b, a); }
friend bool operator==(unsigned int a, mpbq const & b) { return operator==(b, a); }
friend bool operator==(int a, mpbq const & b) { return operator==(b, a); }
friend bool operator!=(mpbq const & a, mpbq const & b) { return !operator==(a,b); }
friend bool operator!=(mpbq const & a, mpz const & b) { return !operator==(a,b); }
friend bool operator!=(mpz const & a, mpbq const & b) { return !operator==(a,b); }
friend bool operator!=(mpbq const & a, mpq const & b) { return !operator==(a,b); }
friend bool operator!=(mpbq const & a, unsigned int b) { return !operator==(a,b); }
friend bool operator!=(mpbq const & a, int b) { return !operator==(a,b); }
friend bool operator!=(mpz const & a, mpbq const & b) { return !operator==(a,b); }
friend bool operator!=(mpq const & a, mpbq const & b) { return !operator==(a,b); }
friend bool operator!=(unsigned int a, mpbq const & b) { return !operator==(a,b); }
friend bool operator!=(int a, mpbq const & b) { return !operator==(a,b); }
@ -148,6 +162,8 @@ public:
mpbq & operator--() { return operator-=(1); }
mpbq operator--(int) { mpbq r(*this); --(*this); return r; }
friend void power(mpbq & a, mpbq const & b, unsigned k);
/**
\brief Return the magnitude of a = b/2^k.
It is defined as:
@ -178,14 +194,53 @@ public:
friend void mul2(mpbq & a);
// a <- a*2^k
friend void mul2k(mpbq & a, unsigned k);
// a <- b * 2^k
friend void mul2k(mpbq & a, mpbq const & b, unsigned k) { a = b; mul2k(a, k); }
// a <- b / 2^k
friend void div2k(mpbq & a, mpbq const & b, unsigned k);
// a <- a/2
friend void div2(mpbq & a) { bool old_k_zero = (a.m_k == 0); a.m_k++; if (old_k_zero) a.normalize(); }
// a <- a/2^k
friend void div2k(mpbq & a, unsigned k) { bool old_k_zero = (a.m_k == 0); a.m_k += k; if (old_k_zero) a.normalize(); }
// a <- b / 2^k
friend void div2k(mpbq & a, mpbq const & b, unsigned k) { a = b; div2k(a, k); }
/**
\brief Return true if b^{1/n} is a binary rational, and store the result in a.
Otherwise, return false and return an lower bound based on the integer root of the
numerator and denominator/n
*/
friend bool root_lower(mpbq & a, mpbq const & b, unsigned n);
friend bool root_upper(mpbq & a, mpbq const & b, unsigned n);
/**
\brief Given a rational q which cannot be represented as a binary rational,
and an interval (l, u) s.t. l < q < u. This method stores in u, a u' s.t.
q < u' < u.
In the refinement process, the lower bound l may be also refined to l'
s.t. l < l' < q
*/
friend void refine_upper(mpq const & q, mpbq & l, mpbq & u);
/**
\brief Similar to refine_upper.
*/
friend void refine_lower(mpq const & q, mpbq & l, mpbq & u);
/**
\brief Return true iff a < 1/2^k
*/
friend bool lt_1div2k(mpbq const & a, unsigned k);
friend std::ostream & operator<<(std::ostream & out, mpbq const & v);
friend void display_decimal(std::ostream & out, mpbq const & a, unsigned prec);
class decimal {
mpbq const & m_val;
unsigned m_prec;
public:
decimal(mpbq const & val, unsigned prec = 10):m_val(val), m_prec(prec) {}
friend std::ostream & operator<<(std::ostream & out, decimal const & d) { display_decimal(out, d.m_val, d.m_prec); return out; }
};
};
}

View file

@ -8,6 +8,17 @@ Author: Leonardo de Moura
namespace lean {
int cmp(mpq const & a, mpz const & b) {
if (a.is_integer()) {
return mpz_cmp(mpq_numref(a.m_val), mpq::zval(b));
}
else {
static thread_local mpz tmp;
mpz_mul(mpq::zval(tmp), mpq_denref(a.m_val), mpq::zval(b));
return mpz_cmp(mpq_numref(a.m_val), mpq::zval(tmp));
}
}
void mpq::floor() {
if (is_integer())
return;
@ -62,6 +73,32 @@ std::ostream & operator<<(std::ostream & out, mpq const & v) {
return out;
}
void display_decimal(std::ostream & out, mpq const & a, unsigned prec) {
mpz n1, d1, v1;
numerator(n1, a);
denominator(d1, a);
if (a.is_neg()) {
out << "-";
neg(n1);
}
v1 = n1 / d1;
out << v1;
n1 = rem(n1, d1);
if (n1.is_zero())
return;
out << ".";
for (unsigned i = 0; i < prec; i++) {
n1 *= 10;
v1 = n1 / d1;
lean_assert(v1 < 10);
out << v1;
n1 = rem(n1, d1);
if (n1.is_zero())
return;
}
out << "?";
}
}
void pp(lean::mpq const & v) { std::cout << v << std::endl; }

View file

@ -9,7 +9,9 @@ Author: Leonardo de Moura
namespace lean {
// Wrapper for GMP rationals
/**
\brief Wrapper for GMP rationals
*/
class mpq {
mpq_t m_val;
static mpz_t const & zval(mpz const & v) { return v.m_val; }
@ -68,30 +70,39 @@ public:
bool is_integer() const { return mpz_cmp_ui(mpq_denref(m_val), 1u) == 0; }
friend int cmp(mpq const & a, mpq const & b) { return mpq_cmp(a.m_val, b.m_val); }
friend int cmp(mpq const & a, mpz const & b);
friend int cmp(mpq const & a, unsigned b) { return mpq_cmp_ui(a.m_val, b, 1); }
friend int cmp(mpq const & a, int b) { return mpq_cmp_si(a.m_val, b, 1); }
friend bool operator<(mpq const & a, mpq const & b) { return cmp(a, b) < 0; }
friend bool operator<(mpq const & a, mpz const & b) { return cmp(a, b) < 0; }
friend bool operator<(mpq const & a, unsigned b) { return cmp(a, b) < 0; }
friend bool operator<(mpq const & a, int b) { return cmp(a, b) < 0; }
friend bool operator<(mpz const & a, mpq const & b) { return cmp(b, a) > 0; }
friend bool operator<(unsigned a, mpq const & b) { return cmp(b, a) > 0; }
friend bool operator<(int a, mpq const & b) { return cmp(b, a) > 0; }
friend bool operator>(mpq const & a, mpq const & b) { return cmp(a, b) > 0; }
friend bool operator>(mpq const & a, mpz const & b) { return cmp(a, b) > 0; }
friend bool operator>(mpq const & a, unsigned b) { return cmp(a, b) > 0; }
friend bool operator>(mpq const & a, int b) { return cmp(a, b) > 0; }
friend bool operator>(mpz const & a, mpq const & b) { return cmp(b, a) < 0; }
friend bool operator>(unsigned a, mpq const & b) { return cmp(b, a) < 0; }
friend bool operator>(int a, mpq const & b) { return cmp(b, a) < 0; }
friend bool operator<=(mpq const & a, mpq const & b) { return cmp(a, b) <= 0; }
friend bool operator<=(mpq const & a, mpz const & b) { return cmp(a, b) <= 0; }
friend bool operator<=(mpq const & a, unsigned b) { return cmp(a, b) <= 0; }
friend bool operator<=(mpq const & a, int b) { return cmp(a, b) <= 0; }
friend bool operator<=(mpz const & a, mpq const & b) { return cmp(b, a) >= 0; }
friend bool operator<=(unsigned a, mpq const & b) { return cmp(b, a) >= 0; }
friend bool operator<=(int a, mpq const & b) { return cmp(b, a) >= 0; }
friend bool operator>=(mpq const & a, mpq const & b) { return cmp(a, b) >= 0; }
friend bool operator>=(mpq const & a, mpz const & b) { return cmp(a, b) >= 0; }
friend bool operator>=(mpq const & a, unsigned b) { return cmp(a, b) >= 0; }
friend bool operator>=(mpq const & a, int b) { return cmp(a, b) >= 0; }
friend bool operator>=(mpz const & a, mpq const & b) { return cmp(b, a) <= 0; }
friend bool operator>=(unsigned a, mpq const & b) { return cmp(b, a) <= 0; }
friend bool operator>=(int a, mpq const & b) { return cmp(b, a) <= 0; }
@ -105,9 +116,9 @@ public:
friend bool operator!=(mpq const & a, mpq const & b) { return !operator==(a,b); }
friend bool operator!=(mpq const & a, mpz const & b) { return !operator==(a,b); }
friend bool operator!=(mpz const & a, mpq const & b) { return !operator==(a,b); }
friend bool operator!=(mpq const & a, unsigned int b) { return !operator==(a,b); }
friend bool operator!=(mpq const & a, int b) { return !operator==(a,b); }
friend bool operator!=(mpz const & a, mpq const & b) { return !operator==(a,b); }
friend bool operator!=(unsigned int a, mpq const & b) { return !operator==(a,b); }
friend bool operator!=(int a, mpq const & b) { return !operator==(a,b); }
@ -169,16 +180,32 @@ public:
mpq & operator--() { return operator-=(1); }
mpq operator--(int) { mpq r(*this); --(*this); return r; }
mpz get_numerator() const { return mpz(mpq_numref(m_val)); }
mpz get_denominator() const { return mpz(mpq_denref(m_val)); }
friend std::ostream & operator<<(std::ostream & out, mpq const & v);
// a <- numerator(b)
friend void numerator(mpz & a, mpq const & b) { mpz_set(a.m_val, mpq_numref(b.m_val)); }
// a <- denominator(b)
friend void denominator(mpz & a, mpq const & b) { mpz_set(a.m_val, mpq_denref(b.m_val)); }
mpz get_numerator() const { mpz r; numerator(r, *this); return r; }
mpz get_denominator() const { mpz r; denominator(r, *this); return r; }
void floor();
friend mpz floor(mpq const & a);
void ceil();
friend mpz ceil(mpq const & a);
friend std::ostream & operator<<(std::ostream & out, mpq const & v);
friend void display_decimal(std::ostream & out, mpq const & a, unsigned prec);
class decimal {
mpq const & m_val;
unsigned m_prec;
public:
decimal(mpq const & val, unsigned prec = 10):m_val(val), m_prec(prec) {}
friend std::ostream & operator<<(std::ostream & out, decimal const & d) { display_decimal(out, d.m_val, d.m_prec); return out; }
};
};
}

View file

@ -52,6 +52,12 @@ mpz operator%(mpz const & a, mpz const & b) {
return r;
}
bool root(mpz & root, mpz const & a, unsigned k) {
static thread_local mpz rem;
mpz_rootrem(root.m_val, rem.m_val, a.m_val, k);
return rem.is_zero();
}
void display(std::ostream & out, __mpz_struct const * v) {
size_t sz = mpz_sizeinbase(v, 10) + 2;
if (sz < 1024) {

View file

@ -12,7 +12,9 @@ Author: Leonardo de Moura
namespace lean {
class mpq;
// Wrapper for GMP integers
/**
\brief Wrapper for GMP integers
*/
class mpz {
friend class mpq;
mpz_t m_val;
@ -198,10 +200,12 @@ public:
*/
unsigned power_of_two_multiple() const { return mpz_scan1(m_val, 0); }
friend mpz power(mpz const & a, unsigned k) { mpz r; mpz_pow_ui(r.m_val, a.m_val, k); return r; }
friend void power(mpz & a, mpz const & b, unsigned k) { mpz_pow_ui(a.m_val, b.m_val, k); }
friend mpz power(mpz const & a, unsigned k) { mpz r; power(r, a, k); return r; }
friend void rootrem(mpz & root, mpz & rem, mpz const & a, unsigned k) { mpz_rootrem(root.m_val, rem.m_val, a.m_val, k); }
friend void root(mpz & root, mpz const & a, unsigned k) { mpz_root(root.m_val, a.m_val, k); }
// root <- a^{1/k}, return true iff the result is an integer
friend bool root(mpz & root, mpz const & a, unsigned k);
friend mpz root(mpz const & a, unsigned k) { mpz r; root(r, a, k); return r; }
friend void gcd(mpz & g, mpz const & a, mpz const & b) { mpz_gcd(g.m_val, a.m_val, b.m_val); }