Add missing operators to mpz, mpq, mpbq. Add pp functions for debugging

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-07-17 12:43:05 -07:00
parent c0f9f06d70
commit eaa76ee9d2
6 changed files with 232 additions and 20 deletions

View file

@ -15,41 +15,196 @@ void mpbq::normalize() {
m_k = 0;
return;
}
unsigned k = m_num.power_of_two_multiple();
if (k > m_k)
k = m_k;
m_num.div2k(k);
div2k(m_num, m_num, k);
m_k -= k;
}
int cmp(mpbq const & a, mpbq const & b) {
static thread_local mpz tmp;
if (a.m_k == b.m_k)
return cmp(a.m_num, b.m_num);
else if (a.m_k < b.m_k) {
mpz tmp(a.m_num);
tmp.mul2k(b.m_k - a.m_k);
mul2k(tmp, a.m_num, b.m_k - a.m_k);
return cmp(tmp, b.m_num);
}
else {
lean_assert(a.m_k > b.m_k);
mpz tmp(b.m_num);
tmp.mul2k(a.m_k - b.m_k);
mul2k(tmp, b.m_num, a.m_k - b.m_k);
return cmp(a.m_num, tmp);
}
}
int cmp(mpbq const & a, mpz const & b) {
static thread_local mpz tmp;
if (a.m_k == 0)
return cmp(a.m_num, b);
else {
mpz tmp(b);
tmp.mul2k(a.m_k);
mul2k(tmp, b, a.m_k);
return cmp(a.m_num, tmp);
}
}
mpbq & mpbq::operator+=(mpbq const & a) {
if (m_k == a.m_k) {
m_num += a.m_num;
}
else if (m_k < a.m_k) {
mul2k(m_num, m_num, a.m_k - m_k);
m_k = a.m_k;
m_num += a.m_num;
}
else {
lean_assert(m_k > a.m_k);
static thread_local mpz tmp;
mul2k(tmp, a.m_num, m_k - a.m_k);
m_num += tmp;
}
normalize();
return *this;
}
template<typename T>
mpbq & mpbq::add_int(T const & a) {
if (m_k == 0) {
m_num += a;
}
else {
lean_assert(m_k > 0);
static thread_local mpz tmp;
tmp = a;
mul2k(tmp, tmp, m_k);
m_num += tmp;
}
normalize();
return *this;
}
mpbq & mpbq::operator+=(unsigned a) { return add_int<unsigned>(a); }
mpbq & mpbq::operator+=(int a) { return add_int<int>(a); }
mpbq & mpbq::operator-=(mpbq const & a) {
if (m_k == a.m_k) {
m_num -= a.m_num;
}
else if (m_k < a.m_k) {
mul2k(m_num, m_num, a.m_k - m_k);
m_k = a.m_k;
m_num -= a.m_num;
}
else {
lean_assert(m_k > a.m_k);
static thread_local mpz tmp;
mul2k(tmp, a.m_num, m_k - a.m_k);
m_num -= tmp;
}
normalize();
return *this;
}
template<typename T>
mpbq & mpbq::sub_int(T const & a) {
if (m_k == 0) {
m_num -= a;
}
else {
lean_assert(m_k > 0);
static thread_local mpz tmp;
tmp = a;
mul2k(tmp, tmp, m_k);
m_num -= tmp;
}
normalize();
return *this;
}
mpbq & mpbq::operator-=(unsigned a) { return sub_int<unsigned>(a); }
mpbq & mpbq::operator-=(int a) { return sub_int<int>(a); }
mpbq & mpbq::operator*=(mpbq const & a) {
m_num *= a.m_num;
if (m_k == 0 || a.m_k == 0) {
m_k += a.m_k;
normalize();
}
else {
m_k += a.m_k;
}
return *this;
}
template<typename T>
mpbq & mpbq::mul_int(T const & a) {
m_num *= a;
normalize();
return *this;
}
mpbq & mpbq::operator*=(unsigned a) { return mul_int<unsigned>(a); }
mpbq & mpbq::operator*=(int a) { return mul_int<int>(a); }
int mpbq::magnitude_lb() const {
int s = m_num.sgn();
if (s < 0) {
return m_num.mlog2() - m_k + 1;
}
else if (s == 0) {
return 0;
}
else {
lean_assert(s > 0);
return m_num.log2() - m_k;
}
}
int mpbq::magnitude_ub() const {
int s = m_num.sgn();
if (s < 0) {
return m_num.mlog2() - m_k;
}
else if (s == 0) {
return 0;
}
else {
lean_assert(s > 0);
return m_num.log2() - m_k + 1;
}
}
void mul2(mpbq & a) {
if (a.m_k == 0) {
mul2k(a.m_num, a.m_num, 1);
}
else {
a.m_k--;
}
}
void mul2k(mpbq & a, unsigned k) {
if (k == 0)
return;
if (a.m_k < k) {
mul2k(a.m_num, a.m_num, k - a.m_k);
a.m_k = 0;
}
else {
lean_assert(a.m_k >= k);
a.m_k -= k;
}
}
std::ostream & operator<<(std::ostream & out, mpbq const & v) {
if (v.m_k == 0) {
out << v.m_num;
}
else if (v.m_k == 1) {
out << v.m_num << "/2";
}
else {
out << v.m_num << "/2^" << v.m_k;
}
return out;
}
}
void pp(lean::mpbq const & n) { std::cout << n << std::endl; }

View file

@ -14,6 +14,9 @@ class mpbq {
mpz m_num;
unsigned m_k;
void normalize();
template<typename T> mpbq & add_int(T const & a);
template<typename T> mpbq & sub_int(T const & a);
template<typename T> mpbq & mul_int(T const & a);
public:
mpbq():m_k(0) {}
mpbq(mpbq const & v):m_num(v.m_num), m_k(v.m_k) {}
@ -23,6 +26,11 @@ public:
mpbq(int n, unsigned k):m_num(n), m_k(k) { normalize(); }
~mpbq() {}
mpbq & operator=(mpbq const & v) { m_num = v.m_num; m_k = v.m_k; return *this; }
mpbq & operator=(mpbq && v) { swap(v); return *this; }
mpbq & operator=(unsigned int v) { m_num = v; m_k = 0; return *this; }
mpbq & operator=(int v) { m_num = v; m_k = 0; return *this; }
void swap(mpbq & o) { m_num.swap(o.m_num); std::swap(m_k, o.m_k); }
unsigned hash() const { return m_num.hash(); }
@ -139,6 +147,43 @@ public:
mpbq & operator--() { return operator-=(1); }
mpbq operator--(int) { mpbq r(*this); --(*this); return r; }
/**
\brief Return the magnitude of a = b/2^k.
It is defined as:
a == 0 -> 0
a > 0 -> log2(b) - k Note that 2^{log2(b) - k} <= a <= 2^{log2(b) - k + 1}
a < 0 -> mlog2(b) - k + 1 Note that -2^{mlog2(b) - k + 1} <= a <= -2^{mlog2(b) - k}
Remark: mlog2(b) = log2(-b)
Examples:
5/2^3 log2(5) - 3 = -1
21/2^2 log2(21) - 2 = 2
-3/2^4 log2(3) - 4 + 1 = -2
*/
int magnitude_lb() const;
/**
\brief Similar to magnitude_lb
a == 0 -> 0
a > 0 -> log2(b) - k + 1 a <= 2^{log2(b) - k + 1}
a < 0 -> mlog2(b) - k a <= -2^{mlog2(b) - k}
*/
int magnitude_ub() const;
// a <- a*2
friend void mul2(mpbq & a);
// a <- a*2^k
friend void mul2k(mpbq & a, unsigned k);
// a <- b * 2^k
friend void mul2k(mpbq & a, mpbq const & b, unsigned k) { a = b; mul2k(a, k); }
// a <- b / 2^k
friend void div2k(mpbq & a, mpbq const & b, unsigned k);
friend std::ostream & operator<<(std::ostream & out, mpbq const & v);
};

View file

@ -62,6 +62,7 @@ std::ostream & operator<<(std::ostream & out, mpq const & v) {
return out;
}
void pp(mpq const & v) { std::cout << v << std::endl; }
}
void pp(lean::mpq const & v) { std::cout << v << std::endl; }

View file

@ -15,8 +15,13 @@ class mpq {
static mpz_t const & zval(mpz const & v) { return v.m_val; }
static mpz_t & zval(mpz & v) { return v.m_val; }
public:
void swap(mpq & v) { mpq_swap(m_val, v.m_val); }
void swap_numerator(mpz & v) { mpz_swap(mpq_numref(m_val), v.m_val); mpq_canonicalize(m_val); }
void swap_denominator(mpz & v) { mpz_swap(mpq_denref(m_val), v.m_val); mpq_canonicalize(m_val); }
mpq & operator=(mpz const & v) { mpq_set_z(m_val, v.m_val); return *this; }
mpq & operator=(mpq const & v) { mpq_set(m_val, v.m_val); return *this; }
mpq & operator=(mpq && v) { swap(v); return *this; }
mpq & operator=(char const * v) { mpq_set_str(m_val, v, 10); return *this; }
mpq & operator=(unsigned long int v) { mpq_set_ui(m_val, v, 1u); return *this; }
mpq & operator=(long int v) { mpq_set_si(m_val, v, 1); return *this; }
@ -39,10 +44,6 @@ public:
mpq(double v):mpq() { mpq_set_d(m_val, v); }
~mpq() { mpq_clear(m_val); }
void swap(mpq & v) { mpq_swap(m_val, v.m_val); }
void swap_numerator(mpz & v) { mpz_swap(mpq_numref(m_val), v.m_val); mpq_canonicalize(m_val); }
void swap_denominator(mpz & v) { mpz_swap(mpq_denref(m_val), v.m_val); mpq_canonicalize(m_val); }
unsigned hash() const { return static_cast<unsigned>(mpz_get_si(mpq_numref(m_val))); }
int sgn() const { return mpq_sgn(m_val); }

View file

@ -72,3 +72,5 @@ std::ostream & operator<<(std::ostream & out, mpz const & v) {
}
}
void pp(lean::mpz const & n) { std::cout << n << std::endl; }

View file

@ -59,6 +59,14 @@ public:
unsigned long int get_unsigned_long_int() const { lean_assert(is_unsigned_long_int()); return mpz_get_ui(m_val); }
unsigned int get_unsigned_int() const { lean_assert(is_unsigned_int()); return static_cast<unsigned>(get_unsigned_long_int()); }
mpz & operator=(mpz const & v) { mpz_set(m_val, v.m_val); return *this; }
mpz & operator=(mpz && v) { swap(v); return *this; }
mpz & operator=(char const * v) { mpz_set_str(m_val, v, 10); return *this; }
mpz & operator=(unsigned long int v) { mpz_set_ui(m_val, v); return *this; }
mpz & operator=(long int v) { mpz_set_si(m_val, v); return *this; }
mpz & operator=(unsigned int v) { return operator=(static_cast<unsigned long int>(v)); }
mpz & operator=(int v) { return operator=(static_cast<long int>(v)); }
friend int cmp(mpz const & a, mpz const & b) { return mpz_cmp(a.m_val, b.m_val); }
friend int cmp(mpz const & a, unsigned b) { return mpz_cmp_ui(a.m_val, b); }
friend int cmp(mpz const & a, int b) { return mpz_cmp_si(a.m_val, b); }
@ -164,11 +172,11 @@ public:
// this <- this - a*b
void submul(mpz const & a, mpz const & b) { mpz_submul(m_val, a.m_val, b.m_val); }
// this <- this * 2^k
void mul2k(unsigned k) { mpz_mul_2exp(m_val, m_val, k); }
// this <- this / 2^k
void div2k(unsigned k) { mpz_tdiv_q_2exp(m_val, m_val, k); }
// a <- b * 2^k
friend void mul2k(mpz & a, mpz const & b, unsigned k) { mpz_mul_2exp(a.m_val, b.m_val, k); }
// a <- b / 2^k
friend void div2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_q_2exp(a.m_val, b.m_val, k); }
/**
\brief Return the position of the most significant bit.
Return 0 if the number is negative