Add missing operators to mpz, mpq, mpbq. Add pp functions for debugging
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
c0f9f06d70
commit
eaa76ee9d2
6 changed files with 232 additions and 20 deletions
|
@ -15,41 +15,196 @@ void mpbq::normalize() {
|
|||
m_k = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned k = m_num.power_of_two_multiple();
|
||||
if (k > m_k)
|
||||
k = m_k;
|
||||
m_num.div2k(k);
|
||||
div2k(m_num, m_num, k);
|
||||
m_k -= k;
|
||||
}
|
||||
|
||||
|
||||
int cmp(mpbq const & a, mpbq const & b) {
|
||||
static thread_local mpz tmp;
|
||||
if (a.m_k == b.m_k)
|
||||
return cmp(a.m_num, b.m_num);
|
||||
else if (a.m_k < b.m_k) {
|
||||
mpz tmp(a.m_num);
|
||||
tmp.mul2k(b.m_k - a.m_k);
|
||||
mul2k(tmp, a.m_num, b.m_k - a.m_k);
|
||||
return cmp(tmp, b.m_num);
|
||||
}
|
||||
else {
|
||||
lean_assert(a.m_k > b.m_k);
|
||||
mpz tmp(b.m_num);
|
||||
tmp.mul2k(a.m_k - b.m_k);
|
||||
mul2k(tmp, b.m_num, a.m_k - b.m_k);
|
||||
return cmp(a.m_num, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
int cmp(mpbq const & a, mpz const & b) {
|
||||
static thread_local mpz tmp;
|
||||
if (a.m_k == 0)
|
||||
return cmp(a.m_num, b);
|
||||
else {
|
||||
mpz tmp(b);
|
||||
tmp.mul2k(a.m_k);
|
||||
mul2k(tmp, b, a.m_k);
|
||||
return cmp(a.m_num, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
mpbq & mpbq::operator+=(mpbq const & a) {
|
||||
if (m_k == a.m_k) {
|
||||
m_num += a.m_num;
|
||||
}
|
||||
else if (m_k < a.m_k) {
|
||||
mul2k(m_num, m_num, a.m_k - m_k);
|
||||
m_k = a.m_k;
|
||||
m_num += a.m_num;
|
||||
}
|
||||
else {
|
||||
lean_assert(m_k > a.m_k);
|
||||
static thread_local mpz tmp;
|
||||
mul2k(tmp, a.m_num, m_k - a.m_k);
|
||||
m_num += tmp;
|
||||
}
|
||||
normalize();
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
mpbq & mpbq::add_int(T const & a) {
|
||||
if (m_k == 0) {
|
||||
m_num += a;
|
||||
}
|
||||
else {
|
||||
lean_assert(m_k > 0);
|
||||
static thread_local mpz tmp;
|
||||
tmp = a;
|
||||
mul2k(tmp, tmp, m_k);
|
||||
m_num += tmp;
|
||||
}
|
||||
normalize();
|
||||
return *this;
|
||||
}
|
||||
mpbq & mpbq::operator+=(unsigned a) { return add_int<unsigned>(a); }
|
||||
mpbq & mpbq::operator+=(int a) { return add_int<int>(a); }
|
||||
|
||||
mpbq & mpbq::operator-=(mpbq const & a) {
|
||||
if (m_k == a.m_k) {
|
||||
m_num -= a.m_num;
|
||||
}
|
||||
else if (m_k < a.m_k) {
|
||||
mul2k(m_num, m_num, a.m_k - m_k);
|
||||
m_k = a.m_k;
|
||||
m_num -= a.m_num;
|
||||
}
|
||||
else {
|
||||
lean_assert(m_k > a.m_k);
|
||||
static thread_local mpz tmp;
|
||||
mul2k(tmp, a.m_num, m_k - a.m_k);
|
||||
m_num -= tmp;
|
||||
}
|
||||
normalize();
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
mpbq & mpbq::sub_int(T const & a) {
|
||||
if (m_k == 0) {
|
||||
m_num -= a;
|
||||
}
|
||||
else {
|
||||
lean_assert(m_k > 0);
|
||||
static thread_local mpz tmp;
|
||||
tmp = a;
|
||||
mul2k(tmp, tmp, m_k);
|
||||
m_num -= tmp;
|
||||
}
|
||||
normalize();
|
||||
return *this;
|
||||
}
|
||||
mpbq & mpbq::operator-=(unsigned a) { return sub_int<unsigned>(a); }
|
||||
mpbq & mpbq::operator-=(int a) { return sub_int<int>(a); }
|
||||
|
||||
mpbq & mpbq::operator*=(mpbq const & a) {
|
||||
m_num *= a.m_num;
|
||||
if (m_k == 0 || a.m_k == 0) {
|
||||
m_k += a.m_k;
|
||||
normalize();
|
||||
}
|
||||
else {
|
||||
m_k += a.m_k;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
mpbq & mpbq::mul_int(T const & a) {
|
||||
m_num *= a;
|
||||
normalize();
|
||||
return *this;
|
||||
}
|
||||
mpbq & mpbq::operator*=(unsigned a) { return mul_int<unsigned>(a); }
|
||||
mpbq & mpbq::operator*=(int a) { return mul_int<int>(a); }
|
||||
|
||||
int mpbq::magnitude_lb() const {
|
||||
int s = m_num.sgn();
|
||||
if (s < 0) {
|
||||
return m_num.mlog2() - m_k + 1;
|
||||
}
|
||||
else if (s == 0) {
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
lean_assert(s > 0);
|
||||
return m_num.log2() - m_k;
|
||||
}
|
||||
}
|
||||
|
||||
int mpbq::magnitude_ub() const {
|
||||
int s = m_num.sgn();
|
||||
if (s < 0) {
|
||||
return m_num.mlog2() - m_k;
|
||||
}
|
||||
else if (s == 0) {
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
lean_assert(s > 0);
|
||||
return m_num.log2() - m_k + 1;
|
||||
}
|
||||
}
|
||||
|
||||
void mul2(mpbq & a) {
|
||||
if (a.m_k == 0) {
|
||||
mul2k(a.m_num, a.m_num, 1);
|
||||
}
|
||||
else {
|
||||
a.m_k--;
|
||||
}
|
||||
}
|
||||
|
||||
void mul2k(mpbq & a, unsigned k) {
|
||||
if (k == 0)
|
||||
return;
|
||||
if (a.m_k < k) {
|
||||
mul2k(a.m_num, a.m_num, k - a.m_k);
|
||||
a.m_k = 0;
|
||||
}
|
||||
else {
|
||||
lean_assert(a.m_k >= k);
|
||||
a.m_k -= k;
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream & operator<<(std::ostream & out, mpbq const & v) {
|
||||
if (v.m_k == 0) {
|
||||
out << v.m_num;
|
||||
}
|
||||
else if (v.m_k == 1) {
|
||||
out << v.m_num << "/2";
|
||||
}
|
||||
else {
|
||||
out << v.m_num << "/2^" << v.m_k;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void pp(lean::mpbq const & n) { std::cout << n << std::endl; }
|
||||
|
|
|
@ -14,6 +14,9 @@ class mpbq {
|
|||
mpz m_num;
|
||||
unsigned m_k;
|
||||
void normalize();
|
||||
template<typename T> mpbq & add_int(T const & a);
|
||||
template<typename T> mpbq & sub_int(T const & a);
|
||||
template<typename T> mpbq & mul_int(T const & a);
|
||||
public:
|
||||
mpbq():m_k(0) {}
|
||||
mpbq(mpbq const & v):m_num(v.m_num), m_k(v.m_k) {}
|
||||
|
@ -23,6 +26,11 @@ public:
|
|||
mpbq(int n, unsigned k):m_num(n), m_k(k) { normalize(); }
|
||||
~mpbq() {}
|
||||
|
||||
mpbq & operator=(mpbq const & v) { m_num = v.m_num; m_k = v.m_k; return *this; }
|
||||
mpbq & operator=(mpbq && v) { swap(v); return *this; }
|
||||
mpbq & operator=(unsigned int v) { m_num = v; m_k = 0; return *this; }
|
||||
mpbq & operator=(int v) { m_num = v; m_k = 0; return *this; }
|
||||
|
||||
void swap(mpbq & o) { m_num.swap(o.m_num); std::swap(m_k, o.m_k); }
|
||||
|
||||
unsigned hash() const { return m_num.hash(); }
|
||||
|
@ -139,6 +147,43 @@ public:
|
|||
|
||||
mpbq & operator--() { return operator-=(1); }
|
||||
mpbq operator--(int) { mpbq r(*this); --(*this); return r; }
|
||||
|
||||
/**
|
||||
\brief Return the magnitude of a = b/2^k.
|
||||
It is defined as:
|
||||
a == 0 -> 0
|
||||
a > 0 -> log2(b) - k Note that 2^{log2(b) - k} <= a <= 2^{log2(b) - k + 1}
|
||||
a < 0 -> mlog2(b) - k + 1 Note that -2^{mlog2(b) - k + 1} <= a <= -2^{mlog2(b) - k}
|
||||
|
||||
Remark: mlog2(b) = log2(-b)
|
||||
|
||||
Examples:
|
||||
|
||||
5/2^3 log2(5) - 3 = -1
|
||||
21/2^2 log2(21) - 2 = 2
|
||||
-3/2^4 log2(3) - 4 + 1 = -2
|
||||
*/
|
||||
int magnitude_lb() const;
|
||||
|
||||
/**
|
||||
\brief Similar to magnitude_lb
|
||||
|
||||
a == 0 -> 0
|
||||
a > 0 -> log2(b) - k + 1 a <= 2^{log2(b) - k + 1}
|
||||
a < 0 -> mlog2(b) - k a <= -2^{mlog2(b) - k}
|
||||
*/
|
||||
int magnitude_ub() const;
|
||||
|
||||
// a <- a*2
|
||||
friend void mul2(mpbq & a);
|
||||
// a <- a*2^k
|
||||
friend void mul2k(mpbq & a, unsigned k);
|
||||
|
||||
// a <- b * 2^k
|
||||
friend void mul2k(mpbq & a, mpbq const & b, unsigned k) { a = b; mul2k(a, k); }
|
||||
// a <- b / 2^k
|
||||
friend void div2k(mpbq & a, mpbq const & b, unsigned k);
|
||||
|
||||
|
||||
friend std::ostream & operator<<(std::ostream & out, mpbq const & v);
|
||||
};
|
||||
|
|
|
@ -62,6 +62,7 @@ std::ostream & operator<<(std::ostream & out, mpq const & v) {
|
|||
return out;
|
||||
}
|
||||
|
||||
void pp(mpq const & v) { std::cout << v << std::endl; }
|
||||
|
||||
}
|
||||
|
||||
void pp(lean::mpq const & v) { std::cout << v << std::endl; }
|
||||
|
||||
|
|
|
@ -15,8 +15,13 @@ class mpq {
|
|||
static mpz_t const & zval(mpz const & v) { return v.m_val; }
|
||||
static mpz_t & zval(mpz & v) { return v.m_val; }
|
||||
public:
|
||||
void swap(mpq & v) { mpq_swap(m_val, v.m_val); }
|
||||
void swap_numerator(mpz & v) { mpz_swap(mpq_numref(m_val), v.m_val); mpq_canonicalize(m_val); }
|
||||
void swap_denominator(mpz & v) { mpz_swap(mpq_denref(m_val), v.m_val); mpq_canonicalize(m_val); }
|
||||
|
||||
mpq & operator=(mpz const & v) { mpq_set_z(m_val, v.m_val); return *this; }
|
||||
mpq & operator=(mpq const & v) { mpq_set(m_val, v.m_val); return *this; }
|
||||
mpq & operator=(mpq && v) { swap(v); return *this; }
|
||||
mpq & operator=(char const * v) { mpq_set_str(m_val, v, 10); return *this; }
|
||||
mpq & operator=(unsigned long int v) { mpq_set_ui(m_val, v, 1u); return *this; }
|
||||
mpq & operator=(long int v) { mpq_set_si(m_val, v, 1); return *this; }
|
||||
|
@ -39,10 +44,6 @@ public:
|
|||
mpq(double v):mpq() { mpq_set_d(m_val, v); }
|
||||
~mpq() { mpq_clear(m_val); }
|
||||
|
||||
void swap(mpq & v) { mpq_swap(m_val, v.m_val); }
|
||||
void swap_numerator(mpz & v) { mpz_swap(mpq_numref(m_val), v.m_val); mpq_canonicalize(m_val); }
|
||||
void swap_denominator(mpz & v) { mpz_swap(mpq_denref(m_val), v.m_val); mpq_canonicalize(m_val); }
|
||||
|
||||
unsigned hash() const { return static_cast<unsigned>(mpz_get_si(mpq_numref(m_val))); }
|
||||
|
||||
int sgn() const { return mpq_sgn(m_val); }
|
||||
|
|
|
@ -72,3 +72,5 @@ std::ostream & operator<<(std::ostream & out, mpz const & v) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
void pp(lean::mpz const & n) { std::cout << n << std::endl; }
|
||||
|
|
|
@ -59,6 +59,14 @@ public:
|
|||
unsigned long int get_unsigned_long_int() const { lean_assert(is_unsigned_long_int()); return mpz_get_ui(m_val); }
|
||||
unsigned int get_unsigned_int() const { lean_assert(is_unsigned_int()); return static_cast<unsigned>(get_unsigned_long_int()); }
|
||||
|
||||
mpz & operator=(mpz const & v) { mpz_set(m_val, v.m_val); return *this; }
|
||||
mpz & operator=(mpz && v) { swap(v); return *this; }
|
||||
mpz & operator=(char const * v) { mpz_set_str(m_val, v, 10); return *this; }
|
||||
mpz & operator=(unsigned long int v) { mpz_set_ui(m_val, v); return *this; }
|
||||
mpz & operator=(long int v) { mpz_set_si(m_val, v); return *this; }
|
||||
mpz & operator=(unsigned int v) { return operator=(static_cast<unsigned long int>(v)); }
|
||||
mpz & operator=(int v) { return operator=(static_cast<long int>(v)); }
|
||||
|
||||
friend int cmp(mpz const & a, mpz const & b) { return mpz_cmp(a.m_val, b.m_val); }
|
||||
friend int cmp(mpz const & a, unsigned b) { return mpz_cmp_ui(a.m_val, b); }
|
||||
friend int cmp(mpz const & a, int b) { return mpz_cmp_si(a.m_val, b); }
|
||||
|
@ -164,11 +172,11 @@ public:
|
|||
// this <- this - a*b
|
||||
void submul(mpz const & a, mpz const & b) { mpz_submul(m_val, a.m_val, b.m_val); }
|
||||
|
||||
// this <- this * 2^k
|
||||
void mul2k(unsigned k) { mpz_mul_2exp(m_val, m_val, k); }
|
||||
// this <- this / 2^k
|
||||
void div2k(unsigned k) { mpz_tdiv_q_2exp(m_val, m_val, k); }
|
||||
|
||||
// a <- b * 2^k
|
||||
friend void mul2k(mpz & a, mpz const & b, unsigned k) { mpz_mul_2exp(a.m_val, b.m_val, k); }
|
||||
// a <- b / 2^k
|
||||
friend void div2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_q_2exp(a.m_val, b.m_val, k); }
|
||||
|
||||
/**
|
||||
\brief Return the position of the most significant bit.
|
||||
Return 0 if the number is negative
|
||||
|
|
Loading…
Reference in a new issue