diff --git a/src/kernel/CMakeLists.txt b/src/kernel/CMakeLists.txt index 86a769aca..ea1516bab 100644 --- a/src/kernel/CMakeLists.txt +++ b/src/kernel/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(kernel expr.cpp) +add_library(kernel expr.cpp expr_max_shared) target_link_libraries(kernel ${EXTRA_LIBS}) diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 8ab1cafd0..b9390a062 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -123,7 +123,7 @@ bool eq(expr const & a, expr const & b) { if (a.kind() != b.kind()) return false; if (is_var(a)) return get_var_idx(a) == get_var_idx(b); if (is_prop(a)) return true; - if (get_rc(a) > 1 && get_rc(b) > 1) { + if (is_shared(a) && is_shared(b)) { auto p = std::make_pair(a.raw(), b.raw()); if (g_eq_visited.find(p) != g_eq_visited.end()) return true; diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 0e5ab667b..54c46fed3 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -260,6 +260,7 @@ inline expr_numeral * to_numeral(expr const & e) { return to_numeral(e.r // ======================================= // Accessors inline unsigned get_rc(expr_cell * e) { return e->get_rc(); } +inline bool is_shared(expr_cell * e) { return get_rc(e) > 1; } inline unsigned get_var_idx(expr_cell * e) { return to_var(e)->get_vidx(); } inline name const & get_const_name(expr_cell * e) { return to_constant(e)->get_name(); } inline unsigned get_const_pos(expr_cell * e) { return to_constant(e)->get_pos(); } @@ -273,6 +274,7 @@ inline uvar const & get_ty_var(expr_cell * e, unsigned idx) { return to_type(e)- inline mpz const & get_numeral(expr_cell * e) { return to_numeral(e)->get_num(); } inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); } +inline bool is_shared(expr const & e) { return get_rc(e) > 1; } inline unsigned get_var_idx(expr const & e) { return to_var(e)->get_vidx(); } inline name const & get_const_name(expr const & e) { return to_constant(e)->get_name(); } inline unsigned get_const_pos(expr const & e) { return to_constant(e)->get_pos(); } diff --git a/src/kernel/expr_set.h b/src/kernel/expr_set.h index 0e7d7d7ef..e85ddfc5c 100644 --- a/src/kernel/expr_set.h +++ b/src/kernel/expr_set.h @@ -7,6 +7,7 @@ Author: Leonardo de Moura #pragma once #include #include "expr.h" +#include "expr_functors.h" #include "hash.h" namespace lean { @@ -14,12 +15,6 @@ namespace lean { // ======================================= // Expression Set // Remark: to expressions are assumed to be equal if they are "pointer-equal" -struct expr_hash { - unsigned operator()(expr const & e) const { return e.hash(); } -}; -struct expr_eqp { - bool operator()(expr const & e1, expr const & e2) const { return eqp(e1, e2); } -}; typedef std::unordered_set expr_set; // ======================================= @@ -30,12 +25,6 @@ typedef std::unordered_set expr_set; // WARNING: use with care, this kind of set // does not prevent an expression from being // garbage collected. -struct expr_cell_hash { - unsigned operator()(expr_cell * e) const { return e->hash(); } -}; -struct expr_cell_eqp { - bool operator()(expr_cell * e1, expr_cell * e2) const { return e1 == e2; } -}; typedef std::unordered_set expr_cell_set; // ======================================= diff --git a/src/tests/kernel/expr.cpp b/src/tests/kernel/expr.cpp index be655a5ae..47a1271f4 100644 --- a/src/tests/kernel/expr.cpp +++ b/src/tests/kernel/expr.cpp @@ -5,11 +5,13 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include "expr.h" +#include "expr_max_shared.h" +#include "expr_set.h" #include "test.h" #include using namespace lean; -static void tst1() { +void tst1() { expr a; a = numeral(mpz(10)); expr f; @@ -96,7 +98,7 @@ unsigned depth3(expr const & e) { return m; } -static void tst2() { +void tst2() { expr r1 = mk_dag(20); expr r2 = mk_dag(20); lean_verify(r1 == r2); @@ -111,14 +113,14 @@ expr mk_big(expr f, unsigned depth, unsigned val) { return app({f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1)}); } -static void tst3() { +void tst3() { expr f = constant(name("f")); expr r1 = mk_big(f, 18, 0); expr r2 = mk_big(f, 18, 0); lean_verify(r1 == r2); } -static void tst4() { +void tst4() { expr f = constant(name("f")); expr a = var(0); for (unsigned i = 0; i < 10000; i++) { @@ -126,14 +128,55 @@ static void tst4() { } } +expr mk_redundant_dag(expr f, unsigned depth) { + if (depth == 0) + return var(0); + else + return app({f, mk_redundant_dag(f, depth - 1), mk_redundant_dag(f, depth - 1)}); +} + + +unsigned count_core(expr const & a, expr_set & s) { + if (s.find(a) != s.end()) + return 0; + s.insert(a); + switch (a.kind()) { + case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral: + return 1; + case expr_kind::App: + return std::accumulate(begin_args(a), end_args(a), 1, + [&](unsigned sum, expr const & arg){ return sum + count_core(arg, s); }); + case expr_kind::Lambda: case expr_kind::Pi: + return count_core(get_abs_type(a), s) + count_core(get_abs_expr(a), s) + 1; + } + return 0; +} + +unsigned count(expr const & a) { + expr_set s; + return count_core(a, s); +} + +void tst5() { + expr f = constant(name("f")); + expr r1 = mk_redundant_dag(f, 1); + expr r2 = max_shared(r1); + std::cout << "r1: " << r1 << "\n"; + std::cout << "r2: " << r1 << "\n"; + std::cout << "count(r1): " << count(r1) << "\n"; + std::cout << "count(r2): " << count(r2) << "\n"; + lean_assert(r1 == r2); +} + int main() { // continue_on_violation(true); std::cout << "sizeof(expr): " << sizeof(expr) << "\n"; std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n"; - tst1(); - tst2(); - tst3(); - tst4(); + // tst1(); + // tst2(); + // tst3(); + // tst4(); + tst5(); std::cout << "done" << "\n"; return has_violations() ? 1 : 0; }