Replace expr == with recursive function. Add goodies for traversing expressions.

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-07-22 16:40:17 -07:00
parent c4cd6c4f84
commit 06320c8615
3 changed files with 155 additions and 101 deletions

View file

@ -115,97 +115,62 @@ void expr_cell::dealloc() {
}
}
bool operator==(expr const & a, expr const & b) {
if (eqp(a, b))
return true;
if (a.hash() != b.hash() || a.kind() != b.kind())
return false;
static thread_local std::vector<expr_cell_pair> todo;
static thread_local expr_cell_pair_set visited;
auto visit = [&](expr_cell * a, expr_cell * b) -> bool {
if (a == b)
namespace expr_eq {
static thread_local expr_cell_pair_set g_eq_visited;
bool eq(expr const & a, expr const & b) {
if (eqp(a, b)) return true;
if (a.hash() != b.hash()) return false;
if (a.kind() != b.kind()) return false;
if (is_var(a)) return get_var_idx(a) == get_var_idx(b);
if (is_prop(a)) return true;
if (get_rc(a) > 1 && get_rc(b) > 1) {
auto p = std::make_pair(a.raw(), b.raw());
if (g_eq_visited.find(p) != g_eq_visited.end())
return true;
if (a->hash() != b->hash())
return false;
if (a->kind() != b->kind())
return false;
if (a->kind() == expr_kind::Prop)
return true;
if (a->kind() == expr_kind::Var)
return get_var_idx(a) == get_var_idx(b);
expr_cell_pair p(a, b);
if (visited.find(p) != visited.end())
return true;
todo.push_back(p);
visited.insert(p);
return true;
};
todo.clear();
visited.clear();
visit(a.raw(), b.raw());
while (!todo.empty()) {
auto p = todo.back();
expr_cell * a = p.first;
expr_cell * b = p.second;
todo.pop_back();
lean_assert(a != b);
lean_assert(a->hash() == b->hash());
lean_assert(a->kind() == b->kind());
switch (a->kind()) {
case expr_kind::Var:
lean_unreachable();
break;
case expr_kind::Constant:
if (get_const_name(a) != get_const_name(b))
return false;
break;
case expr_kind::App:
if (get_num_args(a) != get_num_args(b))
return false;
for (unsigned i = 0; i < get_num_args(a); i++) {
if (!visit(get_arg(a, i).raw(), get_arg(b, i).raw()))
return false;
}
break;
case expr_kind::Lambda:
case expr_kind::Pi:
// Lambda and Pi
// Remark: we ignore get_abs_name because we want alpha-equivalence
if (!visit(get_abs_type(a).raw(), get_abs_type(b).raw()) ||
!visit(get_abs_expr(a).raw(), get_abs_expr(b).raw()))
return false;
break;
case expr_kind::Prop:
lean_unreachable();
break;
case expr_kind::Type:
if (get_ty_num_vars(a) != get_ty_num_vars(b))
return false;
for (unsigned i = 0; i < get_ty_num_vars(a); i++) {
uvar v1 = get_ty_var(a, i);
uvar v2 = get_ty_var(b, i);
if (v1.first != v2.first || v1.second != v2.second)
return false;
}
break;
case expr_kind::Numeral:
if (get_numeral(a) != get_numeral(b))
return false;
break;
}
g_eq_visited.insert(p);
}
return true;
switch (a.kind()) {
case expr_kind::Var: lean_unreachable(); return true;
case expr_kind::Constant: return get_const_name(a) == get_const_name(b);
case expr_kind::App:
if (get_num_args(a) != get_num_args(b))
return false;
for (unsigned i = 0; i < get_num_args(a); i++)
if (!eq(get_arg(a, i), get_arg(b, i)))
return false;
return true;
case expr_kind::Lambda:
case expr_kind::Pi:
// Lambda and Pi
// Remark: we ignore get_abs_name because we want alpha-equivalence
return eq(get_abs_type(a), get_abs_type(b)) && eq(get_abs_expr(a), get_abs_expr(b));
case expr_kind::Prop: lean_unreachable(); return true;
case expr_kind::Type:
if (get_ty_num_vars(a) != get_ty_num_vars(b))
return false;
for (unsigned i = 0; i < get_ty_num_vars(a); i++) {
uvar v1 = get_ty_var(a, i);
uvar v2 = get_ty_var(b, i);
if (v1.first != v2.first || v1.second != v2.second)
return false;
}
return true;
case expr_kind::Numeral: return get_numeral(a) == get_numeral(b);
}
lean_unreachable();
return false;
}
} // namespace expr_eq
bool operator==(expr const & a, expr const & b) {
expr_eq::g_eq_visited.clear();
return expr_eq::eq(a, b);
}
// Low-level pretty printer
std::ostream & operator<<(std::ostream & out, expr const & a) {
switch (a.kind()) {
case expr_kind::Var:
out << "#" << get_var_idx(a);
break;
case expr_kind::Constant:
out << get_const_name(a);
break;
case expr_kind::Var: out << "#" << get_var_idx(a); break;
case expr_kind::Constant: out << get_const_name(a); break;
case expr_kind::App:
out << "(";
for (unsigned i = 0; i < get_num_args(a); i++) {
@ -214,21 +179,11 @@ std::ostream & operator<<(std::ostream & out, expr const & a) {
}
out << ")";
break;
case expr_kind::Lambda:
out << "(fun (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")";
break;
case expr_kind::Pi:
out << "(forall (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")";
break;
case expr_kind::Prop:
out << "Prop";
break;
case expr_kind::Type:
out << "Type";
break;
case expr_kind::Numeral:
out << get_numeral(a);
break;
case expr_kind::Lambda: out << "(fun (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")"; break;
case expr_kind::Pi: out << "(forall (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")"; break;
case expr_kind::Prop: out << "Prop"; break;
case expr_kind::Type: out << "Type"; break;
case expr_kind::Numeral: out << get_numeral(a); break;
}
return out;
}

View file

@ -146,6 +146,8 @@ public:
~expr_app();
unsigned get_num_args() const { return m_num_args; }
expr const & get_arg(unsigned idx) const { lean_assert(idx < m_num_args); return m_args[idx]; }
expr const * begin_args() const { return m_args; }
expr const * end_args() const { return m_args + m_num_args; }
};
// 4. Abstraction
class expr_abstraction : public expr_cell {
@ -257,6 +259,7 @@ inline expr_numeral * to_numeral(expr const & e) { return to_numeral(e.r
// =======================================
// Accessors
inline unsigned get_rc(expr_cell * e) { return e->get_rc(); }
inline unsigned get_var_idx(expr_cell * e) { return to_var(e)->get_vidx(); }
inline name const & get_const_name(expr_cell * e) { return to_constant(e)->get_name(); }
inline unsigned get_const_pos(expr_cell * e) { return to_constant(e)->get_pos(); }
@ -269,11 +272,14 @@ inline unsigned get_ty_num_vars(expr_cell * e) { return to_type(e)-
inline uvar const & get_ty_var(expr_cell * e, unsigned idx) { return to_type(e)->get_var(idx); }
inline mpz const & get_numeral(expr_cell * e) { return to_numeral(e)->get_num(); }
inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); }
inline unsigned get_var_idx(expr const & e) { return to_var(e)->get_vidx(); }
inline name const & get_const_name(expr const & e) { return to_constant(e)->get_name(); }
inline unsigned get_const_pos(expr const & e) { return to_constant(e)->get_pos(); }
inline unsigned get_num_args(expr const & e) { return to_app(e)->get_num_args(); }
inline expr const & get_arg(expr const & e, unsigned idx) { return to_app(e)->get_arg(idx); }
inline expr const * begin_args(expr const & e) { return to_app(e)->begin_args(); }
inline expr const * end_args(expr const & e) { return to_app(e)->end_args(); }
inline name const & get_abs_name(expr const & e) { return to_abstraction(e)->get_name(); }
inline expr const & get_abs_type(expr const & e) { return to_abstraction(e)->get_type(); }
inline expr const & get_abs_expr(expr const & e) { return to_abstraction(e)->get_expr(); }
@ -290,4 +296,18 @@ inline bool operator!=(expr const & a, expr const & b) { return !operator==(a, b
std::ostream & operator<<(std::ostream & out, expr const & a);
/**
\brief Wrapper for iterating over application arguments.
If n is an application, it allows us to write
for (expr const & arg : app_args(n)) {
... do something with argument
}
*/
struct app_args {
expr const & m_app;
app_args(expr const & a):m_app(a) { lean_assert(is_app(a)); }
expr const * begin() const { return &get_arg(m_app, 0); }
expr const * end() const { return begin() + get_num_args(m_app); }
};
}

View file

@ -6,6 +6,7 @@ Author: Leonardo de Moura
*/
#include "expr.h"
#include "test.h"
#include <algorithm>
using namespace lean;
static void tst1() {
@ -34,9 +35,86 @@ expr mk_dag(unsigned depth) {
return a;
}
unsigned depth1(expr const & e) {
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
return 1;
case expr_kind::App: {
unsigned m = 0;
for (expr const & a : app_args(e))
m = std::max(m, depth1(a));
return m + 1;
}
case expr_kind::Lambda: case expr_kind::Pi:
return std::max(depth1(get_abs_type(e)), depth1(get_abs_expr(e))) + 1;
}
return 0;
}
// This is the fastest depth implementation in this file.
unsigned depth2(expr const & e) {
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
return 1;
case expr_kind::App:
return
std::accumulate(begin_args(e), end_args(e), 0,
[](unsigned m, expr const & arg){ return std::max(depth2(arg), m); })
+ 1;
case expr_kind::Lambda: case expr_kind::Pi:
return std::max(depth2(get_abs_type(e)), depth2(get_abs_expr(e))) + 1;
}
return 0;
}
// This is the slowest depth implementation in this file.
unsigned depth3(expr const & e) {
static std::vector<std::pair<expr const *, unsigned>> todo;
unsigned m = 0;
todo.push_back(std::make_pair(&e, 0));
while (!todo.empty()) {
auto const & p = todo.back();
expr const & e = *(p.first);
unsigned c = p.second + 1;
todo.pop_back();
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
m = std::max(c, m);
break;
case expr_kind::App: {
unsigned num = get_num_args(e);
for (unsigned i = 0; i < num; i++)
todo.push_back(std::make_pair(&get_arg(e, i), c));
break;
}
case expr_kind::Lambda: case expr_kind::Pi:
todo.push_back(std::make_pair(&get_abs_type(e), c));
todo.push_back(std::make_pair(&get_abs_expr(e), c));
break;
}
}
return m;
}
static void tst2() {
expr r1 = mk_dag(24);
expr r2 = mk_dag(24);
expr r1 = mk_dag(20);
expr r2 = mk_dag(20);
lean_verify(r1 == r2);
std::cout << depth2(r1) << "\n";
lean_verify(depth2(r1) == 21);
}
expr mk_big(expr f, unsigned depth, unsigned val) {
if (depth == 1)
return var(val);
else
return app({f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1)});
}
static void tst3() {
expr f = constant(name("f"));
expr r1 = mk_big(f, 18, 0);
expr r2 = mk_big(f, 18, 0);
lean_verify(r1 == r2);
}
@ -46,5 +124,6 @@ int main() {
std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n";
tst1();
tst2();
tst3();
return has_violations() ? 1 : 0;
}