Replace expr == with recursive function. Add goodies for traversing expressions.
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
c4cd6c4f84
commit
06320c8615
3 changed files with 155 additions and 101 deletions
|
@ -115,97 +115,62 @@ void expr_cell::dealloc() {
|
|||
}
|
||||
}
|
||||
|
||||
bool operator==(expr const & a, expr const & b) {
|
||||
if (eqp(a, b))
|
||||
return true;
|
||||
if (a.hash() != b.hash() || a.kind() != b.kind())
|
||||
return false;
|
||||
static thread_local std::vector<expr_cell_pair> todo;
|
||||
static thread_local expr_cell_pair_set visited;
|
||||
auto visit = [&](expr_cell * a, expr_cell * b) -> bool {
|
||||
if (a == b)
|
||||
namespace expr_eq {
|
||||
static thread_local expr_cell_pair_set g_eq_visited;
|
||||
bool eq(expr const & a, expr const & b) {
|
||||
if (eqp(a, b)) return true;
|
||||
if (a.hash() != b.hash()) return false;
|
||||
if (a.kind() != b.kind()) return false;
|
||||
if (is_var(a)) return get_var_idx(a) == get_var_idx(b);
|
||||
if (is_prop(a)) return true;
|
||||
if (get_rc(a) > 1 && get_rc(b) > 1) {
|
||||
auto p = std::make_pair(a.raw(), b.raw());
|
||||
if (g_eq_visited.find(p) != g_eq_visited.end())
|
||||
return true;
|
||||
if (a->hash() != b->hash())
|
||||
return false;
|
||||
if (a->kind() != b->kind())
|
||||
return false;
|
||||
if (a->kind() == expr_kind::Prop)
|
||||
return true;
|
||||
if (a->kind() == expr_kind::Var)
|
||||
return get_var_idx(a) == get_var_idx(b);
|
||||
expr_cell_pair p(a, b);
|
||||
if (visited.find(p) != visited.end())
|
||||
return true;
|
||||
todo.push_back(p);
|
||||
visited.insert(p);
|
||||
return true;
|
||||
};
|
||||
todo.clear();
|
||||
visited.clear();
|
||||
visit(a.raw(), b.raw());
|
||||
while (!todo.empty()) {
|
||||
auto p = todo.back();
|
||||
expr_cell * a = p.first;
|
||||
expr_cell * b = p.second;
|
||||
todo.pop_back();
|
||||
lean_assert(a != b);
|
||||
lean_assert(a->hash() == b->hash());
|
||||
lean_assert(a->kind() == b->kind());
|
||||
switch (a->kind()) {
|
||||
case expr_kind::Var:
|
||||
lean_unreachable();
|
||||
break;
|
||||
case expr_kind::Constant:
|
||||
if (get_const_name(a) != get_const_name(b))
|
||||
return false;
|
||||
break;
|
||||
case expr_kind::App:
|
||||
if (get_num_args(a) != get_num_args(b))
|
||||
return false;
|
||||
for (unsigned i = 0; i < get_num_args(a); i++) {
|
||||
if (!visit(get_arg(a, i).raw(), get_arg(b, i).raw()))
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case expr_kind::Lambda:
|
||||
case expr_kind::Pi:
|
||||
// Lambda and Pi
|
||||
// Remark: we ignore get_abs_name because we want alpha-equivalence
|
||||
if (!visit(get_abs_type(a).raw(), get_abs_type(b).raw()) ||
|
||||
!visit(get_abs_expr(a).raw(), get_abs_expr(b).raw()))
|
||||
return false;
|
||||
break;
|
||||
case expr_kind::Prop:
|
||||
lean_unreachable();
|
||||
break;
|
||||
case expr_kind::Type:
|
||||
if (get_ty_num_vars(a) != get_ty_num_vars(b))
|
||||
return false;
|
||||
for (unsigned i = 0; i < get_ty_num_vars(a); i++) {
|
||||
uvar v1 = get_ty_var(a, i);
|
||||
uvar v2 = get_ty_var(b, i);
|
||||
if (v1.first != v2.first || v1.second != v2.second)
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case expr_kind::Numeral:
|
||||
if (get_numeral(a) != get_numeral(b))
|
||||
return false;
|
||||
break;
|
||||
}
|
||||
g_eq_visited.insert(p);
|
||||
}
|
||||
return true;
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var: lean_unreachable(); return true;
|
||||
case expr_kind::Constant: return get_const_name(a) == get_const_name(b);
|
||||
case expr_kind::App:
|
||||
if (get_num_args(a) != get_num_args(b))
|
||||
return false;
|
||||
for (unsigned i = 0; i < get_num_args(a); i++)
|
||||
if (!eq(get_arg(a, i), get_arg(b, i)))
|
||||
return false;
|
||||
return true;
|
||||
case expr_kind::Lambda:
|
||||
case expr_kind::Pi:
|
||||
// Lambda and Pi
|
||||
// Remark: we ignore get_abs_name because we want alpha-equivalence
|
||||
return eq(get_abs_type(a), get_abs_type(b)) && eq(get_abs_expr(a), get_abs_expr(b));
|
||||
case expr_kind::Prop: lean_unreachable(); return true;
|
||||
case expr_kind::Type:
|
||||
if (get_ty_num_vars(a) != get_ty_num_vars(b))
|
||||
return false;
|
||||
for (unsigned i = 0; i < get_ty_num_vars(a); i++) {
|
||||
uvar v1 = get_ty_var(a, i);
|
||||
uvar v2 = get_ty_var(b, i);
|
||||
if (v1.first != v2.first || v1.second != v2.second)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
case expr_kind::Numeral: return get_numeral(a) == get_numeral(b);
|
||||
}
|
||||
lean_unreachable();
|
||||
return false;
|
||||
}
|
||||
} // namespace expr_eq
|
||||
bool operator==(expr const & a, expr const & b) {
|
||||
expr_eq::g_eq_visited.clear();
|
||||
return expr_eq::eq(a, b);
|
||||
}
|
||||
|
||||
// Low-level pretty printer
|
||||
std::ostream & operator<<(std::ostream & out, expr const & a) {
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var:
|
||||
out << "#" << get_var_idx(a);
|
||||
break;
|
||||
case expr_kind::Constant:
|
||||
out << get_const_name(a);
|
||||
break;
|
||||
case expr_kind::Var: out << "#" << get_var_idx(a); break;
|
||||
case expr_kind::Constant: out << get_const_name(a); break;
|
||||
case expr_kind::App:
|
||||
out << "(";
|
||||
for (unsigned i = 0; i < get_num_args(a); i++) {
|
||||
|
@ -214,21 +179,11 @@ std::ostream & operator<<(std::ostream & out, expr const & a) {
|
|||
}
|
||||
out << ")";
|
||||
break;
|
||||
case expr_kind::Lambda:
|
||||
out << "(fun (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")";
|
||||
break;
|
||||
case expr_kind::Pi:
|
||||
out << "(forall (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")";
|
||||
break;
|
||||
case expr_kind::Prop:
|
||||
out << "Prop";
|
||||
break;
|
||||
case expr_kind::Type:
|
||||
out << "Type";
|
||||
break;
|
||||
case expr_kind::Numeral:
|
||||
out << get_numeral(a);
|
||||
break;
|
||||
case expr_kind::Lambda: out << "(fun (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")"; break;
|
||||
case expr_kind::Pi: out << "(forall (" << get_abs_name(a) << " : " << get_abs_type(a) << ") " << get_abs_expr(a) << ")"; break;
|
||||
case expr_kind::Prop: out << "Prop"; break;
|
||||
case expr_kind::Type: out << "Type"; break;
|
||||
case expr_kind::Numeral: out << get_numeral(a); break;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
|
|
@ -146,6 +146,8 @@ public:
|
|||
~expr_app();
|
||||
unsigned get_num_args() const { return m_num_args; }
|
||||
expr const & get_arg(unsigned idx) const { lean_assert(idx < m_num_args); return m_args[idx]; }
|
||||
expr const * begin_args() const { return m_args; }
|
||||
expr const * end_args() const { return m_args + m_num_args; }
|
||||
};
|
||||
// 4. Abstraction
|
||||
class expr_abstraction : public expr_cell {
|
||||
|
@ -257,6 +259,7 @@ inline expr_numeral * to_numeral(expr const & e) { return to_numeral(e.r
|
|||
|
||||
// =======================================
|
||||
// Accessors
|
||||
inline unsigned get_rc(expr_cell * e) { return e->get_rc(); }
|
||||
inline unsigned get_var_idx(expr_cell * e) { return to_var(e)->get_vidx(); }
|
||||
inline name const & get_const_name(expr_cell * e) { return to_constant(e)->get_name(); }
|
||||
inline unsigned get_const_pos(expr_cell * e) { return to_constant(e)->get_pos(); }
|
||||
|
@ -269,11 +272,14 @@ inline unsigned get_ty_num_vars(expr_cell * e) { return to_type(e)-
|
|||
inline uvar const & get_ty_var(expr_cell * e, unsigned idx) { return to_type(e)->get_var(idx); }
|
||||
inline mpz const & get_numeral(expr_cell * e) { return to_numeral(e)->get_num(); }
|
||||
|
||||
inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); }
|
||||
inline unsigned get_var_idx(expr const & e) { return to_var(e)->get_vidx(); }
|
||||
inline name const & get_const_name(expr const & e) { return to_constant(e)->get_name(); }
|
||||
inline unsigned get_const_pos(expr const & e) { return to_constant(e)->get_pos(); }
|
||||
inline unsigned get_num_args(expr const & e) { return to_app(e)->get_num_args(); }
|
||||
inline expr const & get_arg(expr const & e, unsigned idx) { return to_app(e)->get_arg(idx); }
|
||||
inline expr const * begin_args(expr const & e) { return to_app(e)->begin_args(); }
|
||||
inline expr const * end_args(expr const & e) { return to_app(e)->end_args(); }
|
||||
inline name const & get_abs_name(expr const & e) { return to_abstraction(e)->get_name(); }
|
||||
inline expr const & get_abs_type(expr const & e) { return to_abstraction(e)->get_type(); }
|
||||
inline expr const & get_abs_expr(expr const & e) { return to_abstraction(e)->get_expr(); }
|
||||
|
@ -290,4 +296,18 @@ inline bool operator!=(expr const & a, expr const & b) { return !operator==(a, b
|
|||
|
||||
std::ostream & operator<<(std::ostream & out, expr const & a);
|
||||
|
||||
/**
|
||||
\brief Wrapper for iterating over application arguments.
|
||||
If n is an application, it allows us to write
|
||||
for (expr const & arg : app_args(n)) {
|
||||
... do something with argument
|
||||
}
|
||||
*/
|
||||
struct app_args {
|
||||
expr const & m_app;
|
||||
app_args(expr const & a):m_app(a) { lean_assert(is_app(a)); }
|
||||
expr const * begin() const { return &get_arg(m_app, 0); }
|
||||
expr const * end() const { return begin() + get_num_args(m_app); }
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ Author: Leonardo de Moura
|
|||
*/
|
||||
#include "expr.h"
|
||||
#include "test.h"
|
||||
#include <algorithm>
|
||||
using namespace lean;
|
||||
|
||||
static void tst1() {
|
||||
|
@ -34,9 +35,86 @@ expr mk_dag(unsigned depth) {
|
|||
return a;
|
||||
}
|
||||
|
||||
unsigned depth1(expr const & e) {
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
|
||||
return 1;
|
||||
case expr_kind::App: {
|
||||
unsigned m = 0;
|
||||
for (expr const & a : app_args(e))
|
||||
m = std::max(m, depth1(a));
|
||||
return m + 1;
|
||||
}
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
return std::max(depth1(get_abs_type(e)), depth1(get_abs_expr(e))) + 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// This is the fastest depth implementation in this file.
|
||||
unsigned depth2(expr const & e) {
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
|
||||
return 1;
|
||||
case expr_kind::App:
|
||||
return
|
||||
std::accumulate(begin_args(e), end_args(e), 0,
|
||||
[](unsigned m, expr const & arg){ return std::max(depth2(arg), m); })
|
||||
+ 1;
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
return std::max(depth2(get_abs_type(e)), depth2(get_abs_expr(e))) + 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// This is the slowest depth implementation in this file.
|
||||
unsigned depth3(expr const & e) {
|
||||
static std::vector<std::pair<expr const *, unsigned>> todo;
|
||||
unsigned m = 0;
|
||||
todo.push_back(std::make_pair(&e, 0));
|
||||
while (!todo.empty()) {
|
||||
auto const & p = todo.back();
|
||||
expr const & e = *(p.first);
|
||||
unsigned c = p.second + 1;
|
||||
todo.pop_back();
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
|
||||
m = std::max(c, m);
|
||||
break;
|
||||
case expr_kind::App: {
|
||||
unsigned num = get_num_args(e);
|
||||
for (unsigned i = 0; i < num; i++)
|
||||
todo.push_back(std::make_pair(&get_arg(e, i), c));
|
||||
break;
|
||||
}
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
todo.push_back(std::make_pair(&get_abs_type(e), c));
|
||||
todo.push_back(std::make_pair(&get_abs_expr(e), c));
|
||||
break;
|
||||
}
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
static void tst2() {
|
||||
expr r1 = mk_dag(24);
|
||||
expr r2 = mk_dag(24);
|
||||
expr r1 = mk_dag(20);
|
||||
expr r2 = mk_dag(20);
|
||||
lean_verify(r1 == r2);
|
||||
std::cout << depth2(r1) << "\n";
|
||||
lean_verify(depth2(r1) == 21);
|
||||
}
|
||||
|
||||
expr mk_big(expr f, unsigned depth, unsigned val) {
|
||||
if (depth == 1)
|
||||
return var(val);
|
||||
else
|
||||
return app({f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1)});
|
||||
}
|
||||
|
||||
static void tst3() {
|
||||
expr f = constant(name("f"));
|
||||
expr r1 = mk_big(f, 18, 0);
|
||||
expr r2 = mk_big(f, 18, 0);
|
||||
lean_verify(r1 == r2);
|
||||
}
|
||||
|
||||
|
@ -46,5 +124,6 @@ int main() {
|
|||
std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n";
|
||||
tst1();
|
||||
tst2();
|
||||
tst3();
|
||||
return has_violations() ? 1 : 0;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue