From 06320c861557078e3a423ba7ae6714594fe38bcd Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 22 Jul 2013 16:40:17 -0700 Subject: [PATCH] Replace expr == with recursive function. Add goodies for traversing expressions. Signed-off-by: Leonardo de Moura --- src/kernel/expr.cpp | 153 ++++++++++++++------------------------ src/kernel/expr.h | 20 +++++ src/tests/kernel/expr.cpp | 83 ++++++++++++++++++++- 3 files changed, 155 insertions(+), 101 deletions(-) diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 09c2bb1e0..8ab1cafd0 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -115,97 +115,62 @@ void expr_cell::dealloc() { } } -bool operator==(expr const & a, expr const & b) { - if (eqp(a, b)) - return true; - if (a.hash() != b.hash() || a.kind() != b.kind()) - return false; - static thread_local std::vector todo; - static thread_local expr_cell_pair_set visited; - auto visit = [&](expr_cell * a, expr_cell * b) -> bool { - if (a == b) +namespace expr_eq { +static thread_local expr_cell_pair_set g_eq_visited; +bool eq(expr const & a, expr const & b) { + if (eqp(a, b)) return true; + if (a.hash() != b.hash()) return false; + if (a.kind() != b.kind()) return false; + if (is_var(a)) return get_var_idx(a) == get_var_idx(b); + if (is_prop(a)) return true; + if (get_rc(a) > 1 && get_rc(b) > 1) { + auto p = std::make_pair(a.raw(), b.raw()); + if (g_eq_visited.find(p) != g_eq_visited.end()) return true; - if (a->hash() != b->hash()) - return false; - if (a->kind() != b->kind()) - return false; - if (a->kind() == expr_kind::Prop) - return true; - if (a->kind() == expr_kind::Var) - return get_var_idx(a) == get_var_idx(b); - expr_cell_pair p(a, b); - if (visited.find(p) != visited.end()) - return true; - todo.push_back(p); - visited.insert(p); - return true; - }; - todo.clear(); - visited.clear(); - visit(a.raw(), b.raw()); - while (!todo.empty()) { - auto p = todo.back(); - expr_cell * a = p.first; - expr_cell * b = p.second; - todo.pop_back(); - lean_assert(a != b); - lean_assert(a->hash() == b->hash()); - lean_assert(a->kind() == b->kind()); - switch (a->kind()) { - case expr_kind::Var: - lean_unreachable(); - break; - case expr_kind::Constant: - if (get_const_name(a) != get_const_name(b)) - return false; - break; - case expr_kind::App: - if (get_num_args(a) != get_num_args(b)) - return false; - for (unsigned i = 0; i < get_num_args(a); i++) { - if (!visit(get_arg(a, i).raw(), get_arg(b, i).raw())) - return false; - } - break; - case expr_kind::Lambda: - case expr_kind::Pi: - // Lambda and Pi - // Remark: we ignore get_abs_name because we want alpha-equivalence - if (!visit(get_abs_type(a).raw(), get_abs_type(b).raw()) || - !visit(get_abs_expr(a).raw(), get_abs_expr(b).raw())) - return false; - break; - case expr_kind::Prop: - lean_unreachable(); - break; - case expr_kind::Type: - if (get_ty_num_vars(a) != get_ty_num_vars(b)) - return false; - for (unsigned i = 0; i < get_ty_num_vars(a); i++) { - uvar v1 = get_ty_var(a, i); - uvar v2 = get_ty_var(b, i); - if (v1.first != v2.first || v1.second != v2.second) - return false; - } - break; - case expr_kind::Numeral: - if (get_numeral(a) != get_numeral(b)) - return false; - break; - } + g_eq_visited.insert(p); } - return true; + switch (a.kind()) { + case expr_kind::Var: lean_unreachable(); return true; + case expr_kind::Constant: return get_const_name(a) == get_const_name(b); + case expr_kind::App: + if (get_num_args(a) != get_num_args(b)) + return false; + for (unsigned i = 0; i < get_num_args(a); i++) + if (!eq(get_arg(a, i), get_arg(b, i))) + return false; + return true; + case expr_kind::Lambda: + case expr_kind::Pi: + // Lambda and Pi + // Remark: we ignore get_abs_name because we want alpha-equivalence + return eq(get_abs_type(a), get_abs_type(b)) && eq(get_abs_expr(a), get_abs_expr(b)); + case expr_kind::Prop: lean_unreachable(); return true; + case expr_kind::Type: + if (get_ty_num_vars(a) != get_ty_num_vars(b)) + return false; + for (unsigned i = 0; i < get_ty_num_vars(a); i++) { + uvar v1 = get_ty_var(a, i); + uvar v2 = get_ty_var(b, i); + if (v1.first != v2.first || v1.second != v2.second) + return false; + } + return true; + case expr_kind::Numeral: return get_numeral(a) == get_numeral(b); + } + lean_unreachable(); + return false; +} +} // namespace expr_eq +bool operator==(expr const & a, expr const & b) { + expr_eq::g_eq_visited.clear(); + return expr_eq::eq(a, b); } // Low-level pretty printer std::ostream & operator<<(std::ostream & out, expr const & a) { switch (a.kind()) { - case expr_kind::Var: - out << "#" << get_var_idx(a); - break; - case expr_kind::Constant: - out << get_const_name(a); - break; + case expr_kind::Var: out << "#" << get_var_idx(a); break; + case expr_kind::Constant: out << get_const_name(a); break; case expr_kind::App: out << "("; for (unsigned i = 0; i < get_num_args(a); i++) { @@ -214,21 +179,11 @@ std::ostream & operator<<(std::ostream & out, expr const & a) { } out << ")"; break; - case expr_kind::Lambda: - out << "(fun (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")"; - break; - case expr_kind::Pi: - out << "(forall (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")"; - break; - case expr_kind::Prop: - out << "Prop"; - break; - case expr_kind::Type: - out << "Type"; - break; - case expr_kind::Numeral: - out << get_numeral(a); - break; + case expr_kind::Lambda: out << "(fun (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")"; break; + case expr_kind::Pi: out << "(forall (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")"; break; + case expr_kind::Prop: out << "Prop"; break; + case expr_kind::Type: out << "Type"; break; + case expr_kind::Numeral: out << get_numeral(a); break; } return out; } diff --git a/src/kernel/expr.h b/src/kernel/expr.h index df912f869..0e5ab667b 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -146,6 +146,8 @@ public: ~expr_app(); unsigned get_num_args() const { return m_num_args; } expr const & get_arg(unsigned idx) const { lean_assert(idx < m_num_args); return m_args[idx]; } + expr const * begin_args() const { return m_args; } + expr const * end_args() const { return m_args + m_num_args; } }; // 4. Abstraction class expr_abstraction : public expr_cell { @@ -257,6 +259,7 @@ inline expr_numeral * to_numeral(expr const & e) { return to_numeral(e.r // ======================================= // Accessors +inline unsigned get_rc(expr_cell * e) { return e->get_rc(); } inline unsigned get_var_idx(expr_cell * e) { return to_var(e)->get_vidx(); } inline name const & get_const_name(expr_cell * e) { return to_constant(e)->get_name(); } inline unsigned get_const_pos(expr_cell * e) { return to_constant(e)->get_pos(); } @@ -269,11 +272,14 @@ inline unsigned get_ty_num_vars(expr_cell * e) { return to_type(e)- inline uvar const & get_ty_var(expr_cell * e, unsigned idx) { return to_type(e)->get_var(idx); } inline mpz const & get_numeral(expr_cell * e) { return to_numeral(e)->get_num(); } +inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); } inline unsigned get_var_idx(expr const & e) { return to_var(e)->get_vidx(); } inline name const & get_const_name(expr const & e) { return to_constant(e)->get_name(); } inline unsigned get_const_pos(expr const & e) { return to_constant(e)->get_pos(); } inline unsigned get_num_args(expr const & e) { return to_app(e)->get_num_args(); } inline expr const & get_arg(expr const & e, unsigned idx) { return to_app(e)->get_arg(idx); } +inline expr const * begin_args(expr const & e) { return to_app(e)->begin_args(); } +inline expr const * end_args(expr const & e) { return to_app(e)->end_args(); } inline name const & get_abs_name(expr const & e) { return to_abstraction(e)->get_name(); } inline expr const & get_abs_type(expr const & e) { return to_abstraction(e)->get_type(); } inline expr const & get_abs_expr(expr const & e) { return to_abstraction(e)->get_expr(); } @@ -290,4 +296,18 @@ inline bool operator!=(expr const & a, expr const & b) { return !operator==(a, b std::ostream & operator<<(std::ostream & out, expr const & a); +/** + \brief Wrapper for iterating over application arguments. + If n is an application, it allows us to write + for (expr const & arg : app_args(n)) { + ... do something with argument + } +*/ +struct app_args { + expr const & m_app; + app_args(expr const & a):m_app(a) { lean_assert(is_app(a)); } + expr const * begin() const { return &get_arg(m_app, 0); } + expr const * end() const { return begin() + get_num_args(m_app); } +}; + } diff --git a/src/tests/kernel/expr.cpp b/src/tests/kernel/expr.cpp index 66a939477..7f6f559e7 100644 --- a/src/tests/kernel/expr.cpp +++ b/src/tests/kernel/expr.cpp @@ -6,6 +6,7 @@ Author: Leonardo de Moura */ #include "expr.h" #include "test.h" +#include using namespace lean; static void tst1() { @@ -34,9 +35,86 @@ expr mk_dag(unsigned depth) { return a; } +unsigned depth1(expr const & e) { + switch (e.kind()) { + case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral: + return 1; + case expr_kind::App: { + unsigned m = 0; + for (expr const & a : app_args(e)) + m = std::max(m, depth1(a)); + return m + 1; + } + case expr_kind::Lambda: case expr_kind::Pi: + return std::max(depth1(get_abs_type(e)), depth1(get_abs_expr(e))) + 1; + } + return 0; +} + +// This is the fastest depth implementation in this file. +unsigned depth2(expr const & e) { + switch (e.kind()) { + case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral: + return 1; + case expr_kind::App: + return + std::accumulate(begin_args(e), end_args(e), 0, + [](unsigned m, expr const & arg){ return std::max(depth2(arg), m); }) + + 1; + case expr_kind::Lambda: case expr_kind::Pi: + return std::max(depth2(get_abs_type(e)), depth2(get_abs_expr(e))) + 1; + } + return 0; +} + +// This is the slowest depth implementation in this file. +unsigned depth3(expr const & e) { + static std::vector> todo; + unsigned m = 0; + todo.push_back(std::make_pair(&e, 0)); + while (!todo.empty()) { + auto const & p = todo.back(); + expr const & e = *(p.first); + unsigned c = p.second + 1; + todo.pop_back(); + switch (e.kind()) { + case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral: + m = std::max(c, m); + break; + case expr_kind::App: { + unsigned num = get_num_args(e); + for (unsigned i = 0; i < num; i++) + todo.push_back(std::make_pair(&get_arg(e, i), c)); + break; + } + case expr_kind::Lambda: case expr_kind::Pi: + todo.push_back(std::make_pair(&get_abs_type(e), c)); + todo.push_back(std::make_pair(&get_abs_expr(e), c)); + break; + } + } + return m; +} + static void tst2() { - expr r1 = mk_dag(24); - expr r2 = mk_dag(24); + expr r1 = mk_dag(20); + expr r2 = mk_dag(20); + lean_verify(r1 == r2); + std::cout << depth2(r1) << "\n"; + lean_verify(depth2(r1) == 21); +} + +expr mk_big(expr f, unsigned depth, unsigned val) { + if (depth == 1) + return var(val); + else + return app({f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1)}); +} + +static void tst3() { + expr f = constant(name("f")); + expr r1 = mk_big(f, 18, 0); + expr r2 = mk_big(f, 18, 0); lean_verify(r1 == r2); } @@ -46,5 +124,6 @@ int main() { std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n"; tst1(); tst2(); + tst3(); return has_violations() ? 1 : 0; }