diff --git a/src/library/expr_lt.cpp b/src/library/expr_lt.cpp index 8fc8a3eed..3bcfacf1a 100644 --- a/src/library/expr_lt.cpp +++ b/src/library/expr_lt.cpp @@ -7,13 +7,14 @@ Author: Leonardo de Moura #include "kernel/expr.h" namespace lean { -bool is_lt(expr const & a, expr const & b) { - if (is_eqp(a, b)) return false; - if (!a && b) return true; // the null expression is the smallest one - if (a && !b) return false; - if (a.kind() != b.kind()) return a.kind() < b.kind(); - if (a == b) return false; - if (is_var(a)) return var_idx(a) < var_idx(b); +bool is_lt(expr const & a, expr const & b, bool use_hash) { + if (is_eqp(a, b)) return false; + if (!a && b) return true; // the null expression is the smallest one + if (a && !b) return false; + if (a.kind() != b.kind()) return a.kind() < b.kind(); + if (use_hash && a.hash() < b.hash()) return true; + if (a == b) return false; + if (is_var(a)) return var_idx(a) < var_idx(b); switch (a.kind()) { case expr_kind::Var: lean_unreachable(); @@ -24,31 +25,31 @@ bool is_lt(expr const & a, expr const & b) { return num_args(a) < num_args(b); for (unsigned i = 0; i < num_args(a); i++) { if (arg(a, i) != arg(b, i)) - return is_lt(arg(a, i), arg(b, i)); + return is_lt(arg(a, i), arg(b, i), use_hash); } - return false; + lean_unreachable(); case expr_kind::Eq: if (eq_lhs(a) != eq_lhs(b)) - return is_lt(eq_lhs(a), eq_lhs(b)); + return is_lt(eq_lhs(a), eq_lhs(b), use_hash); else - return is_lt(eq_rhs(a), eq_rhs(b)); + return is_lt(eq_rhs(a), eq_rhs(b), use_hash); case expr_kind::Lambda: // Remark: we ignore get_abs_name because we want alpha-equivalence case expr_kind::Pi: if (abst_domain(a) != abst_domain(b)) - return is_lt(abst_domain(a), abst_domain(b)); + return is_lt(abst_domain(a), abst_domain(b), use_hash); else - return is_lt(abst_body(a), abst_body(b)); + return is_lt(abst_body(a), abst_body(b), use_hash); case expr_kind::Type: return ty_level(a) < ty_level(b); case expr_kind::Value: return to_value(a) < to_value(b); case expr_kind::Let: if (let_type(a) != let_type(b)) { - return is_lt(let_type(a), let_type(b)); + return is_lt(let_type(a), let_type(b), use_hash); } else if (let_value(a) != let_value(b)){ - return is_lt(let_value(a), let_value(b)); + return is_lt(let_value(a), let_value(b), use_hash); } else { - return is_lt(let_body(a), let_body(b)); + return is_lt(let_body(a), let_body(b), use_hash); } case expr_kind::MetaVar: if (metavar_idx(a) != metavar_idx(b)) { @@ -65,13 +66,13 @@ bool is_lt(expr const & a, expr const & b) { return it1->s() < it2->s(); } else if (it1->is_inst()) { if (it1->v() != it2->v()) - return is_lt(it1->v(), it2->v()); + return is_lt(it1->v(), it2->v(), use_hash); } else { if (it1->n() != it2->n()) return it1->n() < it2->n(); } } - return false; + return it1 == end1 && it2 != end2; } } lean_unreachable(); diff --git a/src/library/expr_lt.h b/src/library/expr_lt.h index 19cc331b1..cbd70b437 100644 --- a/src/library/expr_lt.h +++ b/src/library/expr_lt.h @@ -7,10 +7,16 @@ Author: Leonardo de Moura #pragma once #include "kernel/expr.h" namespace lean { -/** \brief Total order on expressions */ -bool is_lt(expr const & a, expr const & b); -inline bool operator<(expr const & a, expr const & b) { return is_lt(a, b); } -inline bool operator>(expr const & a, expr const & b) { return is_lt(b, a); } -inline bool operator<=(expr const & a, expr const & b) { return !is_lt(b, a); } -inline bool operator>=(expr const & a, expr const & b) { return !is_lt(a, b); } +/** + \brief Total order on expressions. + + \remark If \c use_hash is true, then we use the hash_code to + partially order expressions. Setting use_hash to false is useful + for testing the code. +*/ +bool is_lt(expr const & a, expr const & b, bool use_hash); +inline bool operator<(expr const & a, expr const & b) { return is_lt(a, b, true); } +inline bool operator>(expr const & a, expr const & b) { return is_lt(b, a, true); } +inline bool operator<=(expr const & a, expr const & b) { return !is_lt(b, a, true); } +inline bool operator>=(expr const & a, expr const & b) { return !is_lt(a, b, true); } } diff --git a/src/tests/library/expr_lt.cpp b/src/tests/library/expr_lt.cpp index 0e023b183..c620d083f 100644 --- a/src/tests/library/expr_lt.cpp +++ b/src/tests/library/expr_lt.cpp @@ -14,9 +14,9 @@ Author: Leonardo de Moura using namespace lean; static void lt(expr const & e1, expr const & e2, bool expected) { - lean_assert((e1 < e2) == expected); - lean_assert((e1 < e2) == !(e1 == e2 || e1 > e2)); - lean_assert((e1 < e2) == (e2 > e1)); + lean_assert(is_lt(e1, e2, false) == expected); + lean_assert(is_lt(e1, e2, false) == !(e1 == e2 || (is_lt(e2, e1, false)))); + lean_assert(!(e1.hash() < e2.hash()) || (e1 < e2)) } static void tst1() { @@ -39,6 +39,8 @@ static void tst1() { lt(Eq(Var(1), Var(0)), Eq(Var(1), Var(1)), true); lt(Eq(Var(1), Var(1)), Eq(Var(1), Var(1)), false); lt(Eq(Var(2), Var(1)), Eq(Var(1), Var(1)), false); + lt(Const("f")(Var(0)), Const("f")(Var(0), Const("a")), true); + lt(Const("f")(Var(0), Const("a"), Const("b")), Const("f")(Var(0), Const("a")), false); lt(Const("f")(Var(0), Const("a")), Const("g")(Var(0), Const("a")), true); lt(Const("f")(Var(0), Const("a")), Const("f")(Var(1), Const("a")), true); lt(Const("f")(Var(0), Const("a")), Const("f")(Var(0), Const("b")), true); @@ -63,6 +65,17 @@ static void tst1() { lt(mk_pi("x", Int, Int), mk_pi("y", Real, Bool), true); lt(mk_pi("x", Int, Int), mk_pi("y", Int, Real), true); lt(mk_pi("x", Int, Int), mk_pi("y", Int, Int), false); + meta_ctx ctx1{mk_lift(0, 1), mk_inst(0, Const("a"))}; + meta_ctx ctx2{mk_lift(0, 1), mk_inst(0, Const("b"))}; + meta_ctx ctx3{mk_lift(3, 1), mk_inst(0, Const("a"))}; + meta_ctx ctx4{mk_lift(0, 1), mk_inst(0, Const("a")), mk_inst(0, Const("b"))}; + meta_ctx ctx5{mk_inst(0, Const("a")), mk_inst(0, Const("a"))}; + lt(mk_metavar(0, ctx1), mk_metavar(1, ctx1), true); + lt(mk_metavar(0, ctx1), mk_metavar(0, ctx2), true); + lt(mk_metavar(0, ctx1), mk_metavar(0, ctx3), true); + lt(mk_metavar(0, ctx1), mk_metavar(0, ctx4), true); + lt(mk_metavar(0, ctx1), mk_metavar(0, ctx5), true); + lt(mk_metavar(0, ctx1), mk_metavar(0, ctx1), false); } int main() {