feat(expr_lt): improve expr_lt performance by using hash codes, and add more tests

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-09-25 21:59:58 -07:00
parent 6477708d78
commit db4e5ab0ad
3 changed files with 47 additions and 27 deletions

View file

@ -7,13 +7,14 @@ Author: Leonardo de Moura
#include "kernel/expr.h" #include "kernel/expr.h"
namespace lean { namespace lean {
bool is_lt(expr const & a, expr const & b) { bool is_lt(expr const & a, expr const & b, bool use_hash) {
if (is_eqp(a, b)) return false; if (is_eqp(a, b)) return false;
if (!a && b) return true; // the null expression is the smallest one if (!a && b) return true; // the null expression is the smallest one
if (a && !b) return false; if (a && !b) return false;
if (a.kind() != b.kind()) return a.kind() < b.kind(); if (a.kind() != b.kind()) return a.kind() < b.kind();
if (a == b) return false; if (use_hash && a.hash() < b.hash()) return true;
if (is_var(a)) return var_idx(a) < var_idx(b); if (a == b) return false;
if (is_var(a)) return var_idx(a) < var_idx(b);
switch (a.kind()) { switch (a.kind()) {
case expr_kind::Var: case expr_kind::Var:
lean_unreachable(); lean_unreachable();
@ -24,31 +25,31 @@ bool is_lt(expr const & a, expr const & b) {
return num_args(a) < num_args(b); return num_args(a) < num_args(b);
for (unsigned i = 0; i < num_args(a); i++) { for (unsigned i = 0; i < num_args(a); i++) {
if (arg(a, i) != arg(b, i)) if (arg(a, i) != arg(b, i))
return is_lt(arg(a, i), arg(b, i)); return is_lt(arg(a, i), arg(b, i), use_hash);
} }
return false; lean_unreachable();
case expr_kind::Eq: case expr_kind::Eq:
if (eq_lhs(a) != eq_lhs(b)) if (eq_lhs(a) != eq_lhs(b))
return is_lt(eq_lhs(a), eq_lhs(b)); return is_lt(eq_lhs(a), eq_lhs(b), use_hash);
else else
return is_lt(eq_rhs(a), eq_rhs(b)); return is_lt(eq_rhs(a), eq_rhs(b), use_hash);
case expr_kind::Lambda: // Remark: we ignore get_abs_name because we want alpha-equivalence case expr_kind::Lambda: // Remark: we ignore get_abs_name because we want alpha-equivalence
case expr_kind::Pi: case expr_kind::Pi:
if (abst_domain(a) != abst_domain(b)) if (abst_domain(a) != abst_domain(b))
return is_lt(abst_domain(a), abst_domain(b)); return is_lt(abst_domain(a), abst_domain(b), use_hash);
else else
return is_lt(abst_body(a), abst_body(b)); return is_lt(abst_body(a), abst_body(b), use_hash);
case expr_kind::Type: case expr_kind::Type:
return ty_level(a) < ty_level(b); return ty_level(a) < ty_level(b);
case expr_kind::Value: case expr_kind::Value:
return to_value(a) < to_value(b); return to_value(a) < to_value(b);
case expr_kind::Let: case expr_kind::Let:
if (let_type(a) != let_type(b)) { if (let_type(a) != let_type(b)) {
return is_lt(let_type(a), let_type(b)); return is_lt(let_type(a), let_type(b), use_hash);
} else if (let_value(a) != let_value(b)){ } else if (let_value(a) != let_value(b)){
return is_lt(let_value(a), let_value(b)); return is_lt(let_value(a), let_value(b), use_hash);
} else { } else {
return is_lt(let_body(a), let_body(b)); return is_lt(let_body(a), let_body(b), use_hash);
} }
case expr_kind::MetaVar: case expr_kind::MetaVar:
if (metavar_idx(a) != metavar_idx(b)) { if (metavar_idx(a) != metavar_idx(b)) {
@ -65,13 +66,13 @@ bool is_lt(expr const & a, expr const & b) {
return it1->s() < it2->s(); return it1->s() < it2->s();
} else if (it1->is_inst()) { } else if (it1->is_inst()) {
if (it1->v() != it2->v()) if (it1->v() != it2->v())
return is_lt(it1->v(), it2->v()); return is_lt(it1->v(), it2->v(), use_hash);
} else { } else {
if (it1->n() != it2->n()) if (it1->n() != it2->n())
return it1->n() < it2->n(); return it1->n() < it2->n();
} }
} }
return false; return it1 == end1 && it2 != end2;
} }
} }
lean_unreachable(); lean_unreachable();

View file

@ -7,10 +7,16 @@ Author: Leonardo de Moura
#pragma once #pragma once
#include "kernel/expr.h" #include "kernel/expr.h"
namespace lean { namespace lean {
/** \brief Total order on expressions */ /**
bool is_lt(expr const & a, expr const & b); \brief Total order on expressions.
inline bool operator<(expr const & a, expr const & b) { return is_lt(a, b); }
inline bool operator>(expr const & a, expr const & b) { return is_lt(b, a); } \remark If \c use_hash is true, then we use the hash_code to
inline bool operator<=(expr const & a, expr const & b) { return !is_lt(b, a); } partially order expressions. Setting use_hash to false is useful
inline bool operator>=(expr const & a, expr const & b) { return !is_lt(a, b); } for testing the code.
*/
bool is_lt(expr const & a, expr const & b, bool use_hash);
inline bool operator<(expr const & a, expr const & b) { return is_lt(a, b, true); }
inline bool operator>(expr const & a, expr const & b) { return is_lt(b, a, true); }
inline bool operator<=(expr const & a, expr const & b) { return !is_lt(b, a, true); }
inline bool operator>=(expr const & a, expr const & b) { return !is_lt(a, b, true); }
} }

View file

@ -14,9 +14,9 @@ Author: Leonardo de Moura
using namespace lean; using namespace lean;
static void lt(expr const & e1, expr const & e2, bool expected) { static void lt(expr const & e1, expr const & e2, bool expected) {
lean_assert((e1 < e2) == expected); lean_assert(is_lt(e1, e2, false) == expected);
lean_assert((e1 < e2) == !(e1 == e2 || e1 > e2)); lean_assert(is_lt(e1, e2, false) == !(e1 == e2 || (is_lt(e2, e1, false))));
lean_assert((e1 < e2) == (e2 > e1)); lean_assert(!(e1.hash() < e2.hash()) || (e1 < e2))
} }
static void tst1() { static void tst1() {
@ -39,6 +39,8 @@ static void tst1() {
lt(Eq(Var(1), Var(0)), Eq(Var(1), Var(1)), true); lt(Eq(Var(1), Var(0)), Eq(Var(1), Var(1)), true);
lt(Eq(Var(1), Var(1)), Eq(Var(1), Var(1)), false); lt(Eq(Var(1), Var(1)), Eq(Var(1), Var(1)), false);
lt(Eq(Var(2), Var(1)), Eq(Var(1), Var(1)), false); lt(Eq(Var(2), Var(1)), Eq(Var(1), Var(1)), false);
lt(Const("f")(Var(0)), Const("f")(Var(0), Const("a")), true);
lt(Const("f")(Var(0), Const("a"), Const("b")), Const("f")(Var(0), Const("a")), false);
lt(Const("f")(Var(0), Const("a")), Const("g")(Var(0), Const("a")), true); lt(Const("f")(Var(0), Const("a")), Const("g")(Var(0), Const("a")), true);
lt(Const("f")(Var(0), Const("a")), Const("f")(Var(1), Const("a")), true); lt(Const("f")(Var(0), Const("a")), Const("f")(Var(1), Const("a")), true);
lt(Const("f")(Var(0), Const("a")), Const("f")(Var(0), Const("b")), true); lt(Const("f")(Var(0), Const("a")), Const("f")(Var(0), Const("b")), true);
@ -63,6 +65,17 @@ static void tst1() {
lt(mk_pi("x", Int, Int), mk_pi("y", Real, Bool), true); lt(mk_pi("x", Int, Int), mk_pi("y", Real, Bool), true);
lt(mk_pi("x", Int, Int), mk_pi("y", Int, Real), true); lt(mk_pi("x", Int, Int), mk_pi("y", Int, Real), true);
lt(mk_pi("x", Int, Int), mk_pi("y", Int, Int), false); lt(mk_pi("x", Int, Int), mk_pi("y", Int, Int), false);
meta_ctx ctx1{mk_lift(0, 1), mk_inst(0, Const("a"))};
meta_ctx ctx2{mk_lift(0, 1), mk_inst(0, Const("b"))};
meta_ctx ctx3{mk_lift(3, 1), mk_inst(0, Const("a"))};
meta_ctx ctx4{mk_lift(0, 1), mk_inst(0, Const("a")), mk_inst(0, Const("b"))};
meta_ctx ctx5{mk_inst(0, Const("a")), mk_inst(0, Const("a"))};
lt(mk_metavar(0, ctx1), mk_metavar(1, ctx1), true);
lt(mk_metavar(0, ctx1), mk_metavar(0, ctx2), true);
lt(mk_metavar(0, ctx1), mk_metavar(0, ctx3), true);
lt(mk_metavar(0, ctx1), mk_metavar(0, ctx4), true);
lt(mk_metavar(0, ctx1), mk_metavar(0, ctx5), true);
lt(mk_metavar(0, ctx1), mk_metavar(0, ctx1), false);
} }
int main() { int main() {