Add wrapper from GMP mpz numbers

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-07-16 13:11:24 -07:00
parent 5a0801789b
commit 31563b95bd
3 changed files with 268 additions and 0 deletions

70
src/util/mpz.cpp Normal file
View file

@ -0,0 +1,70 @@
/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <memory>
#include "mpz.h"
namespace lean {
unsigned mpz::log2() const {
if (is_nonpos())
return 0;
unsigned r = mpz_sizeinbase(m_val, 2);
lean_assert(r > 0);
return r - 1;
}
unsigned mpz::mlog2() const {
if (is_nonneg())
return 0;
mpz * _this = const_cast<mpz*>(this);
_this->neg();
lean_assert(is_pos());
unsigned r = mpz_sizeinbase(m_val, 2);
_this->neg();
lean_assert(is_neg());
return r - 1;
}
bool mpz::is_power_of_two(unsigned & shift) const {
if (is_nonpos())
return false;
if (mpz_popcount(m_val) == 1) {
shift = log2();
return true;
}
else {
return false;
}
}
mpz operator%(mpz const & a, mpz const & b) {
mpz r(rem(a, b));
if (r.is_neg()) {
if (b.is_pos())
r += b;
else
r -= b;
}
return r;
}
std::ostream & operator<<(std::ostream & out, mpz const & v) {
size_t sz = mpz_sizeinbase(v.m_val, 10) + 2;
if (sz < 1024) {
char buffer[1024];
mpz_get_str(buffer, 10, v.m_val);
out << buffer;
}
else {
std::unique_ptr<char> buffer(new char[sz]);
mpz_get_str(buffer.get(), 10, v.m_val);
out << buffer.get();
}
return out;
}
}

128
src/util/mpz.h Normal file
View file

@ -0,0 +1,128 @@
/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include <iostream>
#include <gmp.h>
#include "num_macros.h"
#include "debug.h"
namespace lean {
// Wrapper for GMP integers
class mpz {
mpz_t m_val;
public:
mpz() { mpz_init(m_val); }
mpz(char const * v) { mpz_init_set_str(m_val, const_cast<char*>(v), 10); }
mpz(unsigned long int v) { mpz_init_set_ui(m_val, v); }
mpz(long int v) { mpz_init_set_si(m_val, v); }
mpz(unsigned int v) { mpz_init_set_ui(m_val, v); }
mpz(int v) { mpz_init_set_si(m_val, v); }
mpz(mpz const & s) { mpz_init_set(m_val, s.m_val); }
mpz(mpz && s):mpz() { mpz_swap(m_val, s.m_val); }
~mpz() { mpz_clear(m_val); }
int sgn() const { return mpz_sgn(m_val); }
DEFINE_SIGN_METHODS()
void neg() { mpz_neg(m_val, m_val); }
void abs() { mpz_abs(m_val, m_val); }
bool even() const { return mpz_even_p(m_val) != 0; }
bool odd() const { return !even(); }
void swap(mpz & o) { mpz_swap(m_val, o.m_val); }
unsigned hash() const { return static_cast<unsigned>(mpz_get_si(m_val)); }
bool is_int() const { return mpz_fits_sint_p(m_val) != 0; }
bool is_unsigned_int() const { return mpz_fits_uint_p(m_val) != 0; }
bool is_long_int() const { return mpz_fits_slong_p(m_val) != 0; }
bool is_unsigned_long_int() const { return mpz_fits_ulong_p(m_val) != 0; }
long int get_long_int() const { lean_assert(is_long_int()); return mpz_get_si(m_val); }
int get_int() const { lean_assert(is_int()); return static_cast<int>(get_long_int()); }
unsigned long int get_unsigned_long_int() const { lean_assert(is_unsigned_long_int()); return mpz_get_ui(m_val); }
unsigned int get_unsigned_int() const { lean_assert(is_unsigned_int()); return static_cast<unsigned>(get_unsigned_long_int()); }
friend mpz abs(mpz const & a) { mpz r; mpz_abs(r.m_val, a.m_val); return r; }
friend int cmp(mpz const & a, mpz const & b) { return mpz_cmp(a.m_val, b.m_val); }
friend int cmp(mpz const & a, unsigned b) { return mpz_cmp_ui(a.m_val, b); }
friend int cmp(mpz const & a, int b) { return mpz_cmp_si(a.m_val, b); }
DEFINE_ORDER_OPS(mpz)
mpz & operator+=(mpz const & o) { mpz_add(m_val, m_val, o.m_val); return *this; }
mpz & operator+=(unsigned u) { mpz_add_ui(m_val, m_val, u); return *this; }
mpz & operator+=(int u) { if (u >= 0) mpz_add_ui(m_val, m_val, u); else mpz_sub_ui(m_val, m_val, u); return *this; }
mpz & operator-=(mpz const & o) { mpz_sub(m_val, m_val, o.m_val); return *this; }
mpz & operator-=(unsigned u) { mpz_sub_ui(m_val, m_val, u); return *this; }
mpz & operator-=(int u) { if (u >= 0) mpz_sub_ui(m_val, m_val, u); else mpz_add_ui(m_val, m_val, u); return *this; }
mpz & operator*=(mpz const & o) { mpz_mul(m_val, m_val, o.m_val); return *this; }
mpz & operator*=(unsigned u) { mpz_mul_ui(m_val, m_val, u); return *this; }
mpz & operator*=(int u) { mpz_mul_si(m_val, m_val, u); return *this; }
mpz & operator/=(mpz const & o) { mpz_tdiv_q(m_val, m_val, o.m_val); return *this; }
mpz & operator/=(unsigned u) { mpz_tdiv_q_ui(m_val, m_val, u); return *this; }
friend mpz rem(mpz const & a, mpz const & b) { mpz r; mpz_tdiv_r(r.m_val, a.m_val, b.m_val); return r; }
mpz & operator%=(mpz const & o) { mpz r(*this % o); mpz_swap(m_val, r.m_val); return *this; }
mpz & operator&=(mpz const & o) { mpz_and(m_val, m_val, o.m_val); return *this; }
mpz & operator|=(mpz const & o) { mpz_ior(m_val, m_val, o.m_val); return *this; }
mpz & operator^=(mpz const & o) { mpz_xor(m_val, m_val, o.m_val); return *this; }
void comp() { mpz_com(m_val, m_val); }
// this <- this + a*b
void addmul(mpz const & a, mpz const & b) { mpz_addmul(m_val, a.m_val, b.m_val); }
// this <- this - a*b
void submul(mpz const & a, mpz const & b) { mpz_submul(m_val, a.m_val, b.m_val); }
// this <- this * 2^k
void mul2k(unsigned k) { mpz_mul_2exp(m_val, m_val, k); }
// this <- this / 2^k
void div2k(unsigned k) { mpz_tdiv_q_2exp(m_val, m_val, k); }
/**
\brief Return the position of the most significant bit.
Return 0 if the number is negative
*/
unsigned log2() const;
/**
\brief log2(-n)
Return 0 if the number is nonegative
*/
unsigned mlog2() const;
bool perfect_square() const { return mpz_perfect_square_p(m_val); }
bool is_power_of_two() const { return is_pos() && mpz_popcount(m_val) == 1; }
bool is_power_of_two(unsigned & shift) const;
friend mpz power(mpz const & a, unsigned k) { mpz r; mpz_pow_ui(r.m_val, a.m_val, k); return r; }
friend void rootrem(mpz & root, mpz & rem, mpz const & a, unsigned k) { mpz_rootrem(root.m_val, rem.m_val, a.m_val, k); }
friend void root(mpz & root, mpz const & a, unsigned k) { mpz_root(root.m_val, a.m_val, k); }
friend mpz root(mpz const & a, unsigned k) { mpz r; root(r, a, k); return r; }
friend mpz operator&(mpz a, mpz const & b) { return a &= b; }
friend mpz operator|(mpz a, mpz const & b) { return a |= b; }
friend mpz operator^(mpz a, mpz const & b) { return a ^= b; }
friend mpz operator~(mpz a) { a.comp(); return a; }
DEFINE_ARITH_OPS(mpz)
friend mpz operator%(mpz const & a, mpz const & b);
friend void gcd(mpz & g, mpz const & a, mpz const & b) { mpz_gcd(g.m_val, a.m_val, b.m_val); }
friend mpz gcd(mpz const & a, mpz const & b) { mpz r; gcd(r, a, b); return r; }
friend void gcdext(mpz & g, mpz & s, mpz & t, mpz const & a, mpz const & b) { mpz_gcdext(g.m_val, s.m_val, t.m_val, a.m_val, b.m_val); }
friend void lcm(mpz & l, mpz const & a, mpz const & b) { mpz_lcm(l.m_val, a.m_val, b.m_val); }
friend mpz lcm(mpz const & a, mpz const & b) { mpz l; lcm(l, a, b); return l; }
friend std::ostream & operator<<(std::ostream & out, mpz const & v);
};
}

70
src/util/num_macros.h Normal file
View file

@ -0,0 +1,70 @@
/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
// Macros for defining numeral classes
#define DEFINE_ORDER_OPS(T) \
friend bool operator==(T const & a, T const & b) { return cmp(a, b) == 0; } \
friend bool operator!=(T const & a, T const & b) { return cmp(a, b) != 0; } \
friend bool operator<(T const & a, T const & b) { return cmp(a, b) < 0; } \
friend bool operator>(T const & a, T const & b) { return cmp(a, b) > 0; } \
friend bool operator<=(T const & a, T const & b) { return cmp(a, b) <= 0; } \
friend bool operator>=(T const & a, T const & b) { return cmp(a, b) >= 0; } \
friend bool operator==(T const & a, unsigned b) { return cmp(a, b) == 0; } \
friend bool operator!=(T const & a, unsigned b) { return cmp(a, b) != 0; } \
friend bool operator<(T const & a, unsigned b) { return cmp(a, b) < 0; } \
friend bool operator>(T const & a, unsigned b) { return cmp(a, b) > 0; } \
friend bool operator<=(T const & a, unsigned b) { return cmp(a, b) <= 0; } \
friend bool operator>=(T const & a, unsigned b) { return cmp(a, b) >= 0; } \
friend bool operator==(T const & a, int b) { return cmp(a, b) == 0; } \
friend bool operator!=(T const & a, int b) { return cmp(a, b) != 0; } \
friend bool operator<(T const & a, int b) { return cmp(a, b) < 0; } \
friend bool operator>(T const & a, int b) { return cmp(a, b) > 0; } \
friend bool operator<=(T const & a, int b) { return cmp(a, b) <= 0; } \
friend bool operator>=(T const & a, int b) { return cmp(a, b) >= 0; } \
friend bool operator==(unsigned a, T const & b) { return cmp(b, a) == 0; } \
friend bool operator!=(unsigned a, T const & b) { return cmp(b, a) != 0; } \
friend bool operator<(unsigned a, T const & b) { return cmp(b, a) > 0; } \
friend bool operator>(unsigned a, T const & b) { return cmp(b, a) < 0; } \
friend bool operator<=(unsigned a, T const & b) { return cmp(b, a) >= 0; } \
friend bool operator>=(unsigned a, T const & b) { return cmp(b, a) <= 0; } \
friend bool operator==(int a, T const & b) { return cmp(b, a) == 0; } \
friend bool operator!=(int a, T const & b) { return cmp(b, a) != 0; } \
friend bool operator<(int a, T const & b) { return cmp(b, a) > 0; } \
friend bool operator>(int a, T const & b) { return cmp(b, a) < 0; } \
friend bool operator<=(int a, T const & b) { return cmp(b, a) >= 0; } \
friend bool operator>=(int a, T const & b) { return cmp(b, a) <= 0; }
#define DEFINE_SIGN_METHODS() \
bool is_pos() const { return sgn() > 0; } \
bool is_neg() const { return sgn() < 0; } \
bool is_zero() const { return sgn() == 0; } \
bool is_nonpos() const { return !is_pos(); } \
bool is_nonneg() const { return !is_neg(); }
#define DEFINE_ARITH_OPS(T) \
friend T operator+(T a, T const & b) { return a += b; } \
friend T operator-(T a, T const & b) { return a -= b; } \
friend T operator*(T a, T const & b) { return a *= b; } \
friend T operator/(T a, T const & b) { return a /= b; } \
friend T operator+(T a, unsigned b) { return a += b; } \
friend T operator-(T a, unsigned b) { return a -= b; } \
friend T operator*(T a, unsigned b) { return a *= b; } \
friend T operator/(T a, unsigned b) { return a /= b; } \
friend T operator+(T a, int b) { return a += b; } \
friend T operator-(T a, int b) { return a -= b; } \
friend T operator*(T a, int b) { return a *= b; } \
friend T operator/(T a, int b) { return a /= b; } \
friend T operator+(unsigned a, T b) { return b += a; } \
friend T operator-(unsigned a, T b) { b.neg(); return b += a; } \
friend T operator*(unsigned a, T b) { return b *= a; } \
friend T operator/(unsigned a, T const & b) { T r(a); return r /= b; } \
friend T operator+(int a, T b) { return b += a; } \
friend T operator-(int a, T b) { b.neg(); return b += a; } \
friend T operator*(int a, T b) { return b *= a; } \
friend T operator/(int a, T const & b) { T r(a); return r /= b; }