From b8315e559366db15ace61ad74ea37c92140ffa01 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 21 Jul 2013 14:25:56 -0700 Subject: [PATCH] Fix ambiguous overloads. Improve == test for sexprs. Remove redundant code Signed-off-by: Leonardo de Moura --- src/numerics/mpbq.h | 34 +++++++++++++------------------ src/numerics/mpq.h | 9 ++------- src/numerics/mpz.h | 10 +++++----- src/sexpr/sexpr.cpp | 10 ++++++++-- src/sexpr/sexpr.h | 41 ++++++++++++++++++-------------------- src/tests/numerics/mpq.cpp | 37 ++++++++++++++++++++++++++++++++++ src/tests/sexpr/sexpr.cpp | 38 +++++++++++++++++++++++++++++++++++ src/util/name.h | 4 ++-- 8 files changed, 125 insertions(+), 58 deletions(-) diff --git a/src/numerics/mpbq.h b/src/numerics/mpbq.h index ce6649a4e..e2d94b75c 100644 --- a/src/numerics/mpbq.h +++ b/src/numerics/mpbq.h @@ -22,8 +22,8 @@ public: mpbq():m_k(0) {} mpbq(mpbq const & v):m_num(v.m_num), m_k(v.m_k) {} mpbq(mpbq && v); - mpbq(mpz const & v):m_num(v), m_k(0) {} - mpbq(int n):m_num(n), m_k(0) {} + explicit mpbq(mpz const & v):m_num(v), m_k(0) {} + explicit mpbq(int n):m_num(n), m_k(0) {} mpbq(int n, unsigned k):m_num(n), m_k(k) { normalize(); } ~mpbq() {} @@ -133,28 +133,22 @@ public: mpbq & operator*=(int a); friend mpbq operator+(mpbq a, mpbq const & b) { return a += b; } - friend mpbq operator+(mpbq a, mpz const & b) { return a += b; } - friend mpbq operator+(mpbq a, unsigned b) { return a += b; } - friend mpbq operator+(mpbq a, int b) { return a += b; } - friend mpbq operator+(mpz const & a, mpbq b) { return b += a; } - friend mpbq operator+(unsigned a, mpbq b) { return b += a; } - friend mpbq operator+(int a, mpbq b) { return b += a; } + template + friend mpbq operator+(mpbq a, T const & b) { return a += mpbq(b); } + template + friend mpbq operator+(T const & a, mpbq b) { return b += mpbq(a); } friend mpbq operator-(mpbq a, mpbq const & b) { return a -= b; } - friend mpbq operator-(mpbq a, mpz const & b) { return a -= b; } - friend mpbq operator-(mpbq a, unsigned b) { return a -= b; } - friend mpbq operator-(mpbq a, int b) { return a -= b; } - friend mpbq operator-(mpz const & a, mpbq b) { b.neg(); return b += a; } - friend mpbq operator-(unsigned a, mpbq b) { b.neg(); return b += a; } - friend mpbq operator-(int a, mpbq b) { b.neg(); return b += a; } + template + friend mpbq operator-(mpbq a, T const & b) { return a -= mbpq(b); } + template + friend mpbq operator-(T const & a, mpbq b) { b.neg(); return b += mpbq(a); } friend mpbq operator*(mpbq a, mpbq const & b) { return a *= b; } - friend mpbq operator*(mpbq a, mpz const & b) { return a *= b; } - friend mpbq operator*(mpbq a, unsigned b) { return a *= b; } - friend mpbq operator*(mpbq a, int b) { return a *= b; } - friend mpbq operator*(mpz const & a, mpbq b) { return b *= a; } - friend mpbq operator*(unsigned a, mpbq b) { return b *= a; } - friend mpbq operator*(int a, mpbq b) { return b *= a; } + template + friend mpbq operator*(mpbq a, mpz const & b) { return a *= mpbq(b); } + template + friend mpbq operator*(T const & a, mpbq b) { return b *= mpbq(a); } mpbq & operator++() { return operator+=(1); } mpbq operator++(int) { mpbq r(*this); ++(*this); return r; } diff --git a/src/numerics/mpq.h b/src/numerics/mpq.h index 00167346d..893b3b171 100644 --- a/src/numerics/mpq.h +++ b/src/numerics/mpq.h @@ -29,21 +29,16 @@ public: mpq & operator=(long int v) { mpq_set_si(m_val, v, 1); return *this; } mpq & operator=(unsigned int v) { return operator=(static_cast(v)); } mpq & operator=(int v) { return operator=(static_cast(v)); } + mpq & operator=(double v) { mpq_set_d(m_val, v); return *this; } mpq() { mpq_init(m_val); } mpq(mpq const & v):mpq() { operator=(v); } - mpq(mpz const & v):mpq() { operator=(v); } mpq(mpq && s):mpq() { mpq_swap(m_val, s.m_val); } - mpq(char const * v):mpq() { operator=(v); } - mpq(unsigned long int v):mpq() { operator=(v); } - mpq(long int v):mpq() { operator=(v); } - mpq(unsigned int v):mpq() { operator=(v); } - mpq(int v):mpq() { operator=(v); } + template explicit mpq(T const & v):mpq() { operator=(v); } mpq(unsigned long int n, unsigned long int d):mpq() { mpq_set_ui(m_val, n, d); mpq_canonicalize(m_val); } mpq(long int n, long int d):mpq() { mpq_set_si(m_val, n, d); mpq_canonicalize(m_val); } mpq(unsigned int n, unsigned int d):mpq() { mpq_set_ui(m_val, n, d); mpq_canonicalize(m_val); } mpq(int n, int d):mpq() { mpq_set_si(m_val, n, d); mpq_canonicalize(m_val); } - mpq(double v):mpq() { mpq_set_d(m_val, v); } ~mpq() { mpq_clear(m_val); } unsigned hash() const { return static_cast(mpz_get_si(mpq_numref(m_val))); } diff --git a/src/numerics/mpz.h b/src/numerics/mpz.h index 3d7657c06..87aa3d420 100644 --- a/src/numerics/mpz.h +++ b/src/numerics/mpz.h @@ -22,11 +22,11 @@ class mpz { mpz(__mpz_struct const * v) { mpz_init_set(m_val, v); } public: mpz() { mpz_init(m_val); } - mpz(char const * v) { mpz_init_set_str(m_val, const_cast(v), 10); } - mpz(unsigned long int v) { mpz_init_set_ui(m_val, v); } - mpz(long int v) { mpz_init_set_si(m_val, v); } - mpz(unsigned int v) { mpz_init_set_ui(m_val, v); } - mpz(int v) { mpz_init_set_si(m_val, v); } + explicit mpz(char const * v) { mpz_init_set_str(m_val, const_cast(v), 10); } + explicit mpz(unsigned long int v) { mpz_init_set_ui(m_val, v); } + explicit mpz(long int v) { mpz_init_set_si(m_val, v); } + explicit mpz(unsigned int v) { mpz_init_set_ui(m_val, v); } + explicit mpz(int v) { mpz_init_set_si(m_val, v); } mpz(mpz const & s) { mpz_init_set(m_val, s.m_val); } mpz(mpz && s):mpz() { mpz_swap(m_val, s.m_val); } ~mpz() { mpz_clear(m_val); } diff --git a/src/sexpr/sexpr.cpp b/src/sexpr/sexpr.cpp index b3a23f689..7a0cc7ce8 100644 --- a/src/sexpr/sexpr.cpp +++ b/src/sexpr/sexpr.cpp @@ -172,8 +172,8 @@ bool operator==(sexpr const & a, sexpr const & b) { sexpr::kind kb = b.get_kind(); if (ka != kb) return false; - // if (a.hash() != b.hash()) - // return false; + if (a.hash() != b.hash()) + return false; switch (ka) { case sexpr::kind::NIL: return true; case sexpr::kind::STRING: return to_string(a) == to_string(b); @@ -224,4 +224,10 @@ std::ostream & operator<<(std::ostream & out, sexpr const & s) { return out; } +bool operator==(sexpr const & a, name const & b) { return is_name(a) && to_name(a) == b; } +bool operator==(sexpr const & a, mpz const & b) { return is_mpz(a) && to_mpz(a) == b; } +bool operator==(sexpr const & a, mpq const & b) { return is_mpq(a) && to_mpq(a) == b; } + } + +void pp(lean::sexpr const & n) { std::cout << n << "\n"; } diff --git a/src/sexpr/sexpr.h b/src/sexpr/sexpr.h index c5f1ef63a..ee8007f3e 100644 --- a/src/sexpr/sexpr.h +++ b/src/sexpr/sexpr.h @@ -27,21 +27,18 @@ class sexpr { public: enum class kind { NIL, STRING, INT, DOUBLE, NAME, MPZ, MPQ, CONS }; sexpr():m_ptr(nullptr) {} - sexpr(char const * v); - sexpr(std::string const & v); - sexpr(int v); - sexpr(double v); - sexpr(name const & v); - sexpr(mpz const & v); - sexpr(mpq const & v); + explicit sexpr(char const * v); + explicit sexpr(std::string const & v); + explicit sexpr(int v); + explicit sexpr(double v); + explicit sexpr(name const & v); + explicit sexpr(mpz const & v); + explicit sexpr(mpq const & v); sexpr(sexpr const & h, sexpr const & t); - sexpr(char const * v, sexpr const & t):sexpr(sexpr(v), t) {} - sexpr(std::string const & v, sexpr const & t):sexpr(sexpr(v), t) {} - sexpr(int v, sexpr const & t):sexpr(sexpr(v), t) {} - sexpr(double v, sexpr const & t):sexpr(sexpr(v), t) {} - sexpr(name const & v, sexpr const & t):sexpr(sexpr(v), t) {} - sexpr(mpz const & v, sexpr const & t):sexpr(sexpr(v), t) {} - sexpr(mpq const & v, sexpr const & t):sexpr(sexpr(v), t) {} + template + sexpr(T const & h, sexpr const & t):sexpr(sexpr(h), t) {} + template + sexpr(T1 const & h, T2 const & t):sexpr(sexpr(h), sexpr(t)) {} sexpr(sexpr const & s); sexpr(sexpr && s); template @@ -71,11 +68,8 @@ public: sexpr & operator=(sexpr const & s); sexpr & operator=(sexpr&& s); - sexpr & operator=(char const * v) { return operator=(sexpr(v)); } - sexpr & operator=(std::string const & v) { return operator=(sexpr(v)); } - sexpr & operator=(int v) { return operator=(sexpr(v)); } - sexpr & operator=(mpz const & v) { return operator=(sexpr(v)); } - sexpr & operator=(mpq const & v) { return operator=(sexpr(v)); } + template + sexpr & operator=(T const & v) { return operator=(sexpr(v)); } friend void swap(sexpr & a, sexpr & b) { std::swap(a.m_ptr, b.m_ptr); } @@ -115,10 +109,13 @@ bool operator==(sexpr const & a, sexpr const & b); inline bool operator==(sexpr const & a, int b) { return is_int(a) && to_int(a) == b; } inline bool operator==(sexpr const & a, double b) { return is_double(a) && to_double(a) == b; } inline bool operator==(sexpr const & a, std::string const & b) { return is_string(a) && to_string(a) == b; } -inline bool operator==(sexpr const & a, name const & b) { return is_name(a) && to_name(a) == b; } -inline bool operator==(sexpr const & a, mpz const & b) { return is_mpz(a) && to_mpz(a) == b; } -inline bool operator==(sexpr const & a, mpq const & b) { return is_mpq(a) && to_mpq(a) == b; } +bool operator==(sexpr const & a, name const & b); +bool operator==(sexpr const & a, mpz const & b); +bool operator==(sexpr const & a, mpq const & b); +template inline bool operator==(T const & a, sexpr const & b) { return b == a; } +inline bool operator!=(sexpr const & a, sexpr const & b) { return !(a == b); } template inline bool operator!=(sexpr const & a, T const & b) { return !(a == b); } +template inline bool operator!=(T const & a, sexpr const & b) { return !(a == b); } bool operator<(sexpr const & a, sexpr const & b); inline bool operator>(sexpr const & a, sexpr const & b) { return b < a; } inline bool operator<=(sexpr const & a, sexpr const & b) { return !(a > b); } diff --git a/src/tests/numerics/mpq.cpp b/src/tests/numerics/mpq.cpp index 4bc73df0f..69dee8194 100644 --- a/src/tests/numerics/mpq.cpp +++ b/src/tests/numerics/mpq.cpp @@ -91,11 +91,48 @@ static void tst4() { lean_assert(a == mpq(2,5)); } +static void tst5() { + mpq a; + a = 1; + lean_assert(a == mpq(1)); + lean_assert(a <= 1); + lean_assert(a < 2); + lean_assert(a > 0); + lean_assert(a >= 0); + lean_assert(a >= 1); + lean_assert(!(a >= 2)); + lean_assert(a == 1); + lean_assert(1 == a); + lean_assert(a != 2); + lean_assert(!(a == 3)); + lean_assert(a < mpz(2)); + lean_assert(a <= mpz(1)); + lean_assert(a > 0); + lean_assert(a <= 1u); + lean_assert(a < 2u); + lean_assert(a > 0u); + lean_assert(a >= 1u); + lean_assert(a == 1u); + lean_assert(1u >= a); + lean_assert(2u > a); + a = "1/3"; + lean_assert(a == mpq(1,3)); + a = 2.0; + lean_assert(a == mpq(2)); + a = mpz(10); + lean_assert(a == mpq(10)); + lean_assert(a >= mpz(10)); + lean_assert(mpz(10) <= a); + lean_assert(mpz(10) >= a); + lean_assert(mpz(10) == a); +} + int main() { continue_on_violation(true); tst1(); tst2(); tst3(); tst4(); + tst5(); return has_violations() ? 1 : 0; } diff --git a/src/tests/sexpr/sexpr.cpp b/src/tests/sexpr/sexpr.cpp index cf94d24fd..1db013c30 100644 --- a/src/tests/sexpr/sexpr.cpp +++ b/src/tests/sexpr/sexpr.cpp @@ -71,8 +71,46 @@ static void tst1() { lean_assert(!contains(sexpr{10,20,-2,0,10}, [](sexpr e) { return to_int(e) < -10; })); } +void tst2() { + sexpr a; + a = 2; + lean_assert(a == sexpr(2)); + lean_assert(a == 2); + lean_assert(2 == a); + a = 0.125; + lean_assert(a == sexpr(0.125)); + lean_assert(a == 0.125); + lean_assert(0.125 == a); + a = "foo"; + lean_assert(a == sexpr("foo")); + lean_assert(a == "foo"); + lean_assert("foo" == a); + lean_assert(a != "blah"); + lean_assert(a != name("foo")); + lean_assert(std::string("foo") == a); + lean_assert(a == std::string("foo")); + a = name(name("foo"), 1); + lean_assert(a == sexpr(name(name("foo"), 1))); + lean_assert(a == name(name("foo"), 1)); + lean_assert(name(name("foo"), 1) == a); + a = mpq(1,3); + lean_assert(a == sexpr(mpq(1,3))); + lean_assert(a == mpq(1,3)); + lean_assert(mpq(1, 3) == a); + lean_assert(mpq(2, 3) != a); + a = power(mpz(2),100); + lean_assert(a == sexpr(power(mpz(2), 100))); + lean_assert(a == power(mpz(2), 100)); + lean_assert(power(mpz(2), 100) == a); + lean_assert(mpq(power(mpz(2), 100)) != a); + lean_assert(sexpr(1, 2) != sexpr(2, 1)); + lean_assert(sexpr(1, 2) != sexpr(1, sexpr(2, nil()))); + lean_assert(sexpr(1, 2) == sexpr(1, sexpr(2))); +} + int main() { continue_on_violation(true); tst1(); + tst2(); return has_violations() ? 1 : 0; } diff --git a/src/util/name.h b/src/util/name.h index 4b6b42b6e..1791b7ae2 100644 --- a/src/util/name.h +++ b/src/util/name.h @@ -21,8 +21,8 @@ class name { name(imp * p); public: name(); - name(char const * name); - name(unsigned k); + explicit name(char const * name); + explicit name(unsigned k); name(name const & prefix, char const * name); name(name const & prefix, unsigned k); name(name const & other);