diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 30baeaa7f..de00e9ce6 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -20,7 +20,8 @@ unsigned hash_vars(unsigned size, uvar const * vars) { } expr_cell::expr_cell(expr_kind k, unsigned h): - m_kind(k), + m_kind(static_cast(k)), + m_max_shared(0), m_hash(h), m_rc(1) {} @@ -103,7 +104,7 @@ expr_numeral::expr_numeral(mpz const & n): m_numeral(n) {} void expr_cell::dealloc() { - switch (m_kind) { + switch (kind()) { case expr_kind::Var: delete static_cast(this); break; case expr_kind::Constant: delete static_cast(this); break; case expr_kind::App: static_cast(this)->~expr_app(); delete[] reinterpret_cast(this); break; diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 54c46fed3..92ecd444b 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -33,6 +33,8 @@ The main API is divided in the following sections ======================================= */ enum class expr_kind { Var, Constant, App, Lambda, Pi, Prop, Type, Numeral }; +class max_sharing_functor; + /** \brief Base class used to represent expressions. @@ -42,13 +44,17 @@ enum class expr_kind { Var, Constant, App, Lambda, Pi, Prop, Type, Numeral }; */ class expr_cell { protected: - expr_kind m_kind; - unsigned m_hash; + unsigned m_kind:16; + unsigned m_max_shared:1; // flag indicating if the cell has maximally shared subexpressions + unsigned m_hash; MK_LEAN_RC(); // Declare m_rc counter void dealloc(); + bool max_shared() const { return m_max_shared == 1; } + void set_max_shared() { lean_assert(!max_shared()); m_max_shared = 1; } + friend class max_sharing_functor; public: expr_cell(expr_kind k, unsigned h); - expr_kind kind() const { return m_kind; } + expr_kind kind() const { return static_cast(m_kind); } unsigned hash() const { return m_hash; } }; @@ -119,7 +125,7 @@ public: }; // ======================================= -// Expr Representation +// Expr (internal) Representation // 1. Free variables class expr_var : public expr_cell { unsigned m_vidx; // de Bruijn index @@ -225,17 +231,21 @@ inline expr var(unsigned idx) { return expr(new expr_var(idx)); } inline expr constant(name const & n) { return expr(new expr_const(n)); } inline expr constant(name const & n, unsigned pos) { return expr(new expr_const(n, pos)); } expr app(unsigned num_args, expr const * args); -inline expr app(std::initializer_list const & l) { return app(l.size(), l.begin()); } +inline expr app(expr const & e1, expr const & e2) { expr args[2] = {e1, e2}; return app(2, args); } +inline expr app(expr const & e1, expr const & e2, expr const & e3) { expr args[3] = {e1, e2, e3}; return app(3, args); } +inline expr app(expr const & e1, expr const & e2, expr const & e3, expr const & e4) { expr args[4] = {e1, e2, e3, e4}; return app(4, args); } +inline expr app(expr const & e1, expr const & e2, expr const & e3, expr const & e4, expr const & e5) { expr args[5] = {e1, e2, e3, e4, e5}; return app(5, args); } inline expr lambda(name const & n, expr const & t, expr const & e) { return expr(new expr_lambda(n, t, e)); } inline expr pi(name const & n, expr const & t, expr const & e) { return expr(new expr_pi(n, t, e)); } inline expr prop() { return expr(new expr_prop()); } expr type(unsigned size, uvar const * vars); +inline expr type(uvar const & uv) { return type(1, &uv); } inline expr type(std::initializer_list const & l) { return type(l.size(), l.begin()); } inline expr numeral(mpz const & n) { return expr(new expr_numeral(n)); } // ======================================= // ======================================= -// Casting +// Casting (these functions are only needed for low-level code) inline expr_var * to_var(expr_cell * e) { lean_assert(is_var(e)); return static_cast(e); } inline expr_const * to_constant(expr_cell * e) { lean_assert(is_constant(e)); return static_cast(e); } inline expr_app * to_app(expr_cell * e) { lean_assert(is_app(e)); return static_cast(e); } diff --git a/src/kernel/expr_max_shared.cpp b/src/kernel/expr_max_shared.cpp index 7c1189e48..6d0edd0ce 100644 --- a/src/kernel/expr_max_shared.cpp +++ b/src/kernel/expr_max_shared.cpp @@ -6,65 +6,84 @@ Author: Leonardo de Moura */ #include #include -#include "expr.h" +#include "expr_max_shared.h" #include "expr_functors.h" namespace lean { -namespace max_shared_ns { -struct expr_struct_eq { unsigned operator()(expr const & e1, expr const & e2) const { return e1 == e2; }}; -typedef typename std::unordered_set expr_cache; -static thread_local expr_cache g_cache; -expr apply(expr const & a) { - auto r = g_cache.find(a); - if (r != g_cache.end()) - return *r; - switch (a.kind()) { - case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral: - g_cache.insert(a); - return a; - case expr_kind::App: { - std::vector new_args; - bool modified = false; - for (expr const & old_arg : app_args(a)) { - new_args.push_back(apply(old_arg)); - if (!eqp(old_arg, new_args.back())) - modified = true; - } - if (!modified) { - g_cache.insert(a); - return a; - } - else { - expr r = app(static_cast(new_args.size()), new_args.data()); - g_cache.insert(r); - return r; - } + +class max_sharing_functor { + struct expr_struct_eq { unsigned operator()(expr const & e1, expr const & e2) const { return e1 == e2; }}; + typedef typename std::unordered_set expr_cache; + + expr_cache m_cache; + + void cache(expr const & a) { + a.raw()->set_max_shared(); + m_cache.insert(a); } - case expr_kind::Lambda: - case expr_kind::Pi: { - expr const & old_t = get_abs_type(a); - expr const & old_e = get_abs_expr(a); - expr t = apply(old_t); - expr e = apply(old_e); - if (!eqp(t, old_t) || !eqp(e, old_e)) { - name const & n = get_abs_name(a); - expr r = is_pi(a) ? pi(n, t, e) : lambda(n, t, e); - g_cache.insert(r); - return r; - } - else { - g_cache.insert(a); + +public: + + expr apply(expr const & a) { + if (a.raw()->max_shared()) return a; + auto r = m_cache.find(a); + if (r != m_cache.end()) { + lean_assert((*r).raw()->max_shared()); + return *r; } - }} - lean_unreachable(); - return a; -} -} // namespace max_shared_ns -expr max_shared(expr const & a) { - max_shared_ns::g_cache.clear(); - expr r = max_shared_ns::apply(a); - max_shared_ns::g_cache.clear(); - return r; + switch (a.kind()) { + case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral: + cache(a); + return a; + case expr_kind::App: { + std::vector new_args; + bool modified = false; + for (expr const & old_arg : app_args(a)) { + new_args.push_back(apply(old_arg)); + if (!eqp(old_arg, new_args.back())) + modified = true; + } + if (!modified) { + cache(a); + return a; + } + else { + expr r = app(new_args.size(), new_args.data()); + cache(r); + return r; + } + } + case expr_kind::Lambda: + case expr_kind::Pi: { + expr const & old_t = get_abs_type(a); + expr const & old_e = get_abs_expr(a); + expr t = apply(old_t); + expr e = apply(old_e); + if (!eqp(t, old_t) || !eqp(e, old_e)) { + name const & n = get_abs_name(a); + expr r = is_pi(a) ? pi(n, t, e) : lambda(n, t, e); + cache(r); + return r; + } + else { + cache(a); + return a; + } + }} + lean_unreachable(); + return a; + } +}; + +expr max_sharing(expr const & a) { + if (a.raw()->max_shared()) { + return a; + } + else { + max_sharing_functor f; + return f.apply(a); + } } + } // namespace lean diff --git a/src/kernel/expr_max_shared.h b/src/kernel/expr_max_shared.h index 9ee56d4ae..a2bb04a01 100644 --- a/src/kernel/expr_max_shared.h +++ b/src/kernel/expr_max_shared.h @@ -9,7 +9,8 @@ Author: Leonardo de Moura namespace lean { /** - \brief Return a structurally identical expression that is maximally shared. + \brief The resultant expression is structurally identical to the input one, but + it uses maximally shared sub-expressions. */ -expr max_shared(expr const & a); +expr max_sharing(expr const & a); } diff --git a/src/tests/kernel/expr.cpp b/src/tests/kernel/expr.cpp index b988023b8..b5361c7f6 100644 --- a/src/tests/kernel/expr.cpp +++ b/src/tests/kernel/expr.cpp @@ -16,15 +16,17 @@ void tst1() { a = numeral(mpz(10)); expr f; f = var(0); - expr fa = app({f, a}); + expr fa = app(f, a); std::cout << fa << "\n"; - std::cout << app({fa, a}) << "\n"; + std::cout << app(fa, a) << "\n"; lean_assert(eqp(get_arg(fa, 0), f)); lean_assert(eqp(get_arg(fa, 1), a)); - lean_assert(!eqp(fa, app({f, a}))); - lean_assert(app({fa, a}) == app({f, a, a})); - std::cout << app({fa, fa, fa}) << "\n"; + lean_assert(!eqp(fa, app(f, a))); + lean_assert(app(fa, a) == app(f, a, a)); + std::cout << app(fa, fa, fa) << "\n"; std::cout << lambda(name("x"), prop(), var(0)) << "\n"; + lean_assert(app(app(f, a), a) == app(f, a, a)); + lean_assert(app(f, app(a, a)) != app(f, a, a)); } expr mk_dag(unsigned depth) { @@ -32,7 +34,7 @@ expr mk_dag(unsigned depth) { expr a = var(0); while (depth > 0) { depth--; - a = app({f, a, a}); + a = app(f, a, a); } return a; } @@ -110,7 +112,7 @@ expr mk_big(expr f, unsigned depth, unsigned val) { if (depth == 1) return var(val); else - return app({f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1)}); + return app(f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1)); } void tst3() { @@ -124,7 +126,7 @@ void tst4() { expr f = constant(name("f")); expr a = var(0); for (unsigned i = 0; i < 10000; i++) { - a = app({f, a}); + a = app(f, a); } } @@ -132,7 +134,7 @@ expr mk_redundant_dag(expr f, unsigned depth) { if (depth == 0) return var(0); else - return app({f, mk_redundant_dag(f, depth - 1), mk_redundant_dag(f, depth - 1)}); + return app(f, mk_redundant_dag(f, depth - 1), mk_redundant_dag(f, depth - 1)); } @@ -161,20 +163,32 @@ void tst5() { expr f = constant(name("f")); { expr r1 = mk_redundant_dag(f, 5); - expr r2 = max_shared(r1); + expr r2 = max_sharing(r1); std::cout << "count(r1): " << count(r1) << "\n"; std::cout << "count(r2): " << count(r2) << "\n"; lean_assert(r1 == r2); } { expr r1 = mk_redundant_dag(f, 16); - expr r2 = max_shared(r1); + expr r2 = max_sharing(r1); lean_assert(r1 == r2); } } +void tst6() { + expr f = constant(name("f")); + expr r = mk_redundant_dag(f, 12); + for (unsigned i = 0; i < 1000; i++) { + r = max_sharing(r); + } + r = mk_big(f, 16, 0); + for (unsigned i = 0; i < 1000000; i++) { + r = max_sharing(r); + } +} + int main() { - // continue_on_violation(true); + continue_on_violation(true); std::cout << "sizeof(expr): " << sizeof(expr) << "\n"; std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n"; tst1(); @@ -182,6 +196,7 @@ int main() { tst3(); tst4(); tst5(); + tst6(); std::cout << "done" << "\n"; return has_violations() ? 1 : 0; }