From 2d93fe4b76344a8524af2a71fd375c1c84286b26 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 19 Nov 2015 19:37:11 -0800 Subject: [PATCH] feat(library/blast/congruence_closure): implement congruence closure Proof extraction is still missing --- src/library/blast/assert_cc_action.cpp | 6 +- src/library/blast/blast.cpp | 11 + src/library/blast/blast.h | 1 + src/library/blast/congruence_closure.cpp | 318 +++++++++++++++++++++-- src/library/blast/congruence_closure.h | 32 ++- src/library/blast/init_module.cpp | 3 + src/library/blast/simple_strategy.cpp | 2 +- 7 files changed, 343 insertions(+), 30 deletions(-) diff --git a/src/library/blast/assert_cc_action.cpp b/src/library/blast/assert_cc_action.cpp index 234894a62..41ae4f294 100644 --- a/src/library/blast/assert_cc_action.cpp +++ b/src/library/blast/assert_cc_action.cpp @@ -17,7 +17,8 @@ struct cc_branch_extension : public branch_extension { cc_branch_extension() {} cc_branch_extension(cc_branch_extension const & o):m_cc(o.m_cc) {} virtual ~cc_branch_extension() {} - virtual branch_extension * clone() { return new cc_branch_extension(*this); } + virtual branch_extension * clone() override { return new cc_branch_extension(*this); } + virtual void initialized() override { m_cc.initialize(); } }; void initialize_assert_cc_action() { @@ -33,6 +34,9 @@ static congruence_closure & get_cc() { action_result assert_cc_action(hypothesis_idx hidx) { congruence_closure & cc = get_cc(); cc.add(hidx); + // TODO(Leo): remove the following line + return action_result::new_branch(); + cc.display(); if (cc.is_inconsistent()) { try { app_builder & b = get_app_builder(); diff --git a/src/library/blast/blast.cpp b/src/library/blast/blast.cpp index 625e74680..69141e0f5 100644 --- a/src/library/blast/blast.cpp +++ b/src/library/blast/blast.cpp @@ -63,6 +63,7 @@ class blastenv { abstract_expr_manager m_abstract_expr_manager; relation_info_getter m_rel_getter; refl_info_getter m_refl_getter; + symm_info_getter m_symm_getter; class tctx : public type_context { blastenv & m_benv; @@ -442,6 +443,7 @@ public: m_abstract_expr_manager(m_fun_info_manager), m_rel_getter(mk_relation_info_getter(env)), m_refl_getter(mk_refl_info_getter(env)), + m_symm_getter(mk_symm_info_getter(env)), m_tctx(*this), m_normalizer(m_tctx) { init_uref_mref_href_idxs(); @@ -580,6 +582,10 @@ public: return static_cast(m_refl_getter(rop)); } + bool is_symmetric(name const & rop) const { + return static_cast(m_symm_getter(rop)); + } + optional get_relation_info(name const & rop) const { return m_rel_getter(rop); } @@ -643,6 +649,11 @@ bool is_reflexive(name const & rop) { return g_blastenv->is_reflexive(rop); } +bool is_symmetric(name const & rop) { + lean_assert(g_blastenv); + return g_blastenv->is_symmetric(rop); +} + optional get_relation_info(name const & rop) { lean_assert(g_blastenv); return g_blastenv->get_relation_info(rop); diff --git a/src/library/blast/blast.h b/src/library/blast/blast.h index 3e0e84c3e..48faf6d21 100644 --- a/src/library/blast/blast.h +++ b/src/library/blast/blast.h @@ -45,6 +45,7 @@ inline optional is_relation(expr const & R) { return is_constant(R) ? get_relation_info(const_name(R)) : optional(); } bool is_reflexive(name const & rop); +bool is_symmetric(name const & rop); /** \brief Put the given expression in weak-head-normal-form with respect to the current state being processed by the blast tactic. */ diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index 9b8b6875e..6c1cfaab1 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -83,6 +83,11 @@ scope_congruence_closure::~scope_congruence_closure() { g_congr_cache = static_cast(m_old_cache); } +void congruence_closure::initialize() { + mk_entry_for(get_iff_name(), mk_true()); + mk_entry_for(get_iff_name(), mk_false()); +} + void congruence_closure::mk_entry_for(name const & R, expr const & e) { lean_assert(!m_entries.find(eqc_key(R, e))); entry n; @@ -147,13 +152,224 @@ static optional mk_ext_congr_lemma(name const & R, expr const & } void congruence_closure::add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child) { - // TODO(Leo): - std::cout << Rp << parent << Rc << child << "\n"; + child_key k(Rc, child); + parent_occ_set ps; + if (auto old_ps = m_parents.find(k)) + ps = *old_ps; + ps.insert(parent_occ(Rp, parent)); + m_parents.insert(k, ps); +} + +/* Small hack for not storing a pointer to the congruence_closure object + at congruence_closure::congr_key_cmp */ +LEAN_THREAD_PTR(congruence_closure, g_cc); + +/* Auxiliary function for comparing (lhs1 ~ rhs1) and (lhs2 ~ rhs2), + when ~ is symmetric. + It returns 0 (equal) for (a ~ b) (b ~ a) */ +int congruence_closure::compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const { + lhs1 = get_root(R, lhs1); + rhs1 = get_root(R, rhs1); + lhs2 = get_root(R, lhs2); + rhs2 = get_root(R, rhs2); + if (is_lt(lhs1, rhs1, true)) + std::swap(lhs1, rhs1); + if (is_lt(lhs2, rhs2, true)) + std::swap(lhs2, rhs2); + if (lhs1 != lhs2) + return is_lt(lhs1, lhs2, true) ? -1 : 1; + if (rhs1 != rhs2) + return is_lt(rhs1, rhs2, true) ? -1 : 1; + return 0; +} + +int congruence_closure::compare_root(name const & R, expr e1, expr e2) const { + e1 = get_root(R, e1); + e2 = get_root(R, e2); + return expr_quick_cmp()(e1, e2); +} + +int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_key const & k2) const { + if (k1.m_hash != k2.m_hash) + return unsigned_cmp()(k1.m_hash, k2.m_hash); + if (k1.m_R != k2.m_R) + return quick_cmp(k1.m_R, k2.m_R); + if (k1.m_eq != k2.m_eq) + return k1.m_eq ? -1 : 1; + if (k2.m_iff != k2.m_iff) + return k1.m_iff ? -1 : 1; + if (k2.m_symm_rel != k2.m_symm_rel) + return k1.m_symm_rel ? -1 : 1; + if (k1.m_eq || k1.m_iff) { + name const & R = k1.m_eq ? get_eq_name() : get_iff_name(); + expr const & lhs1 = app_arg(app_fn(k1.m_expr)); + expr const & rhs1 = app_arg(k1.m_expr); + expr const & lhs2 = app_arg(app_fn(k2.m_expr)); + expr const & rhs2 = app_arg(k2.m_expr); + return g_cc->compare_symm(R, lhs1, rhs1, lhs2, rhs2); + } else if (k1.m_symm_rel) { + name R1, R2; + expr lhs1, rhs1, lhs2, rhs2; + lean_verify(is_relation_app(k1.m_expr, R1, lhs1, rhs1)); + lean_verify(is_relation_app(k2.m_expr, R2, lhs2, rhs2)); + if (R1 != R2) + return quick_cmp(R1, R2); + return g_cc->compare_symm(R1, lhs1, rhs1, lhs2, rhs2); + } else { + lean_assert(!k1.m_eq && !k2.m_eq && !k1.m_iff && !k2.m_iff && + !k1.m_symm_rel && !k2.m_symm_rel); + lean_assert(k1.m_R == k2.m_R); + buffer args1, args2; + expr const & fn1 = get_app_args(k1.m_expr, args1); + expr const & fn2 = get_app_args(k2.m_expr, args2); + if (args1.size() != args2.size()) + return unsigned_cmp()(args1.size(), args2.size()); + auto lemma = mk_ext_congr_lemma(k1.m_R, fn1, args1.size()); + lean_assert(lemma); + if (!lemma->m_fixed_fun) { + int r = g_cc->compare_root(get_eq_name(), fn1, fn2); + if (r != 0) return r; + for (unsigned i = 0; i < args1.size(); i++) { + r = g_cc->compare_root(get_eq_name(), args1[i], args2[i]); + if (r != 0) return r; + } + return 0; + } else { + list> const * it1 = &lemma->m_rel_names; + list const * it2 = &lemma->m_congr_lemma.get_arg_kinds(); + int r; + for (unsigned i = 0; i < args1.size(); i++) { + lean_assert(*it1); lean_assert(*it2); + switch (head(*it2)) { + case congr_arg_kind::Eq: + lean_assert(head(*it1)); + r = g_cc->compare_root(*head(*it1), args1[i], args2[i]); + if (r != 0) return r; + break; + case congr_arg_kind::Fixed: + r = expr_quick_cmp()(args1[i], args2[i]); + if (r != 0) return r; + break; + case congr_arg_kind::Cast: + // do nothing... ignore argument + break; + } + it1 = &(tail(*it1)); + it2 = &(tail(*it2)); + } + return 0; + } + } +} + +unsigned congruence_closure::symm_hash(name const & R, expr const & lhs, expr const & rhs) const { + unsigned h1 = get_root(R, lhs).hash(); + unsigned h2 = get_root(R, rhs).hash(); + if (h1 > h2) + std::swap(h1, h2); + return (h1 << 16) | (h2 & 0xFFFF); +} + +auto congruence_closure::mk_congr_key(ext_congr_lemma const & lemma, expr const & e) const -> congr_key { + congr_key k; + k.m_R = lemma.m_R; + k.m_expr = e; + lean_assert(is_app(e)); + name R; expr lhs, rhs; + if (is_eq(e, lhs, rhs)) { + k.m_eq = true; + k.m_hash = symm_hash(get_eq_name(), lhs, rhs); + } else if (is_iff(e, lhs, rhs)) { + k.m_iff = true; + k.m_hash = symm_hash(get_iff_name(), lhs, rhs); + } else if (is_relation_app(e, R, lhs, rhs) && is_symmetric(R)) { + k.m_symm_rel = true; + k.m_hash = symm_hash(R, lhs, rhs); + } else { + buffer args; + expr const & fn = get_app_args(e, args); + if (!lemma.m_fixed_fun) { + unsigned h = get_root(get_eq_name(), fn).hash(); + for (unsigned i = 0; i < args.size(); i++) { + h = hash(h, get_root(get_eq_name(), args[i]).hash()); + } + k.m_hash = h; + } else { + unsigned h = fn.hash(); + list> const * it1 = &lemma.m_rel_names; + list const * it2 = &lemma.m_congr_lemma.get_arg_kinds(); + for (unsigned i = 0; i < args.size(); i++) { + lean_assert(*it1); lean_assert(*it2); + switch (head(*it2)) { + case congr_arg_kind::Eq: + lean_assert(head(*it1)); + h = hash(h, get_root(*head(*it1), args[i]).hash()); + break; + case congr_arg_kind::Fixed: + h = hash(h, args[i].hash()); + break; + case congr_arg_kind::Cast: + // do nothing... ignore argument + break; + } + it1 = &(tail(*it1)); + it2 = &(tail(*it2)); + } + k.m_hash = h; + } + } + return k; +} + +static expr * g_congr_mark = nullptr; // dummy congruence proof, it is just a placeholder. +static expr * g_iff_true_mark = nullptr; // dummy iff_true proof, it is just a placeholder. + +typedef std::tuple cc_todo_entry; + +MK_THREAD_LOCAL_GET_DEF(std::vector, get_todo); + +static void clear_todo() { + get_todo().clear(); +} + +void congruence_closure::check_iff_true(congr_key const & k) { + expr const & e = k.m_expr; + name R; expr lhs, rhs; + if (k.m_eq || k.m_iff) { + R = k.m_eq ? get_eq_name() : get_iff_name(); + lhs = app_arg(app_fn(e)); + rhs = app_arg(e); + } else if (k.m_symm_rel) { + lean_verify(is_relation_app(e, R, lhs, rhs)); + } else { + return; + } + if (is_eqv(get_iff_name(), e, mk_true())) + return; // it is already equivalent to true + lhs = get_root(R, lhs); + rhs = get_root(R, rhs); + if (lhs != rhs) + return; + // Add e <-> true + get_todo().emplace_back(get_iff_name(), e, mk_true(), *g_iff_true_mark); } void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, expr const & e) { - // TODO(Leo): - std::cout << lemma.m_R << e << "\n"; + lean_assert(is_app(e)); + congr_key k = mk_congr_key(lemma, e); + if (auto old_k = m_congruences.find(k)) { + // Found new equivalence: e ~ old_k->m_expr + // 1. Update m_cg_root field for e + eqc_key k(lemma.m_R, e); + entry new_entry = *m_entries.find(k); + new_entry.m_cg_root = old_k->m_expr; + m_entries.insert(k, new_entry); + // 2. Put new equivalence in the TODO queue + get_todo().emplace_back(lemma.m_R, e, old_k->m_expr, *g_congr_mark); + } else { + m_congruences.insert(k); + } + check_iff_true(k); } void congruence_closure::internalize(name const & R, expr const & e) { @@ -216,14 +432,6 @@ void congruence_closure::internalize(expr const & e) { internalize(get_eq_name(), e); } -typedef std::tuple cc_todo_entry; - -MK_THREAD_LOCAL_GET_DEF(std::vector, get_todo); - -static void clear_todo() { - get_todo().clear(); -} - /* The fields m_target and m_proof in e's entry are encoding a transitivity proof Let target(e) and proof(e) denote these fields. @@ -253,13 +461,28 @@ void congruence_closure::invert_trans(name const & R, expr const & e) { } void congruence_closure::remove_parents(name const & R, expr const & e) { - std::cout << R << " " << e << "\n"; - // TODO(Leo): + auto ps = m_parents.find(child_key(R, e)); + if (!ps) return; + ps->for_each([&](parent_occ const & p) { + expr const & fn = get_app_fn(p.m_expr); + unsigned nargs = get_app_num_args(p.m_expr); + auto lemma = mk_ext_congr_lemma(p.m_R, fn, nargs); + lean_assert(lemma); + congr_key k = mk_congr_key(*lemma, p.m_expr); + m_congruences.erase(k); + }); } -void congruence_closure::insert_parents(name const & R, expr const & e) { - std::cout << R << " " << e << "\n"; - // TODO(Leo): +void congruence_closure::reinsert_parents(name const & R, expr const & e) { + auto ps = m_parents.find(child_key(R, e)); + if (!ps) return; + ps->for_each([&](parent_occ const & p) { + expr const & fn = get_app_fn(p.m_expr); + unsigned nargs = get_app_num_args(p.m_expr); + auto lemma = mk_ext_congr_lemma(p.m_R, fn, nargs); + lean_assert(lemma); + add_congruence_table(*lemma, p.m_expr); + }); } void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr const & H) { @@ -310,7 +533,7 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con it = new_it_n.m_next; } while (it != e1); - insert_parents(R, e1); + reinsert_parents(R, e1); // update next of e1_root and e2_root, and size of e2_root r1 = m_entries.find(eqc_key(R, e1_root)); @@ -325,6 +548,20 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con m_entries.insert(eqc_key(R, e1_root), new_r1); m_entries.insert(eqc_key(R, e2_root), new_r2); lean_assert(check_invariant()); + + // copy e1_root parents to e2_root + child_key k1(R, e1_root); + auto ps1 = m_parents.find(k1); + if (!ps1) return; // e1_root doesn't have parents + parent_occ_set ps2; + child_key k2(R, e2_root); + if (auto it = m_parents.find(k2)) + ps2 = *it; + ps1->for_each([&](parent_occ const & p) { + ps2.insert(p); + }); + m_parents.erase(k1); + m_parents.insert(k2, ps2); } void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) { @@ -341,6 +578,7 @@ void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const void congruence_closure::add(hypothesis_idx hidx) { if (is_inconsistent()) return; + flet set_cc(g_cc, this); clear_todo(); state & s = curr_state(); app_builder & b = get_app_builder(); @@ -479,7 +717,7 @@ expr congruence_closure::get_root(name const & R, expr const & e) const { if (auto n = m_entries.find(eqc_key(R, e))) { return n->m_root; } else { - return e;; + return e; } } @@ -487,7 +725,7 @@ expr congruence_closure::get_next(name const & R, expr const & e) const { if (auto n = m_entries.find(eqc_key(R, e))) { return n->m_next; } else { - return e;; + return e; } } @@ -505,7 +743,7 @@ void congruence_closure::display_eqc(name const & R, expr const & e) const { out << "}"; } -void congruence_closure::display() const { +void congruence_closure::display_eqcs() const { auto out = diagnostic(env(), ios()); m_entries.for_each([&](eqc_key const & k, entry const & n) { if (k.m_expr == n.m_root) { @@ -515,6 +753,33 @@ void congruence_closure::display() const { }); } +static void display_rel(io_state_stream & out, name const & R) { + if (R != get_eq_name()) + out << "[" << R << "] "; +} + +void congruence_closure::display_parents() const { + auto out = diagnostic(env(), ios()); + m_parents.for_each([&](child_key const & k, parent_occ_set const & ps) { + display_rel(out, k.m_R); + out << ppb(k.m_expr); + out << ", parents: {"; + bool first = true; + ps.for_each([&](parent_occ const & o) { + if (first) first = false; else out << ", "; + display_rel(out, o.m_R); + out << ppb(o.m_expr); + }); + out << "}\n"; + }); +} + +void congruence_closure::display() const { + diagnostic(env(), ios()) << "congruence closure state\n"; + display_eqcs(); + display_parents(); +} + bool congruence_closure::check_eqc(name const & R, expr const & e) const { expr root = get_root(R, e); unsigned size = 0; @@ -548,4 +813,15 @@ bool congruence_closure::check_invariant() const { // TODO(Leo): return true; } + +void initialize_congruence_closure() { + name prefix = name::mk_internal_unique_name(); + g_congr_mark = new expr(mk_constant(name(prefix, "[congruence]"))); + g_iff_true_mark = new expr(mk_constant(name(prefix, "[iff-true]"))); +} + +void finalize_congruence_closure() { + delete g_congr_mark; + delete g_iff_true_mark; +} }} diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index eee7c2eb6..736840a5d 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -76,34 +76,47 @@ class congruence_closure { unsigned m_eq:1; // true if m_expr is an equality unsigned m_iff:1; // true if m_expr is an iff unsigned m_symm_rel:1; // true if m_expr is another symmetric relation. + congr_key() { m_eq = 0; m_iff = 0; m_symm_rel = 0; } }; struct congr_key_cmp { - int operator()(congr_key const & k1, congr_key const & k2); + int operator()(congr_key const & k1, congr_key const & k2) const; }; - typedef rb_tree expr_set; - typedef rb_map entries; - // TODO(Leo): fix and take relation into account - typedef rb_map parents; - typedef rb_tree congruences; + typedef rb_tree expr_set; + typedef rb_map entries; + typedef eqc_key child_key; + typedef eqc_key_cmp child_key_cmp; + typedef eqc_key parent_occ; + typedef eqc_key_cmp parent_occ_cmp; + typedef rb_tree parent_occ_set; + typedef rb_map parents; + typedef rb_tree congruences; entries m_entries; parents m_parents; congruences m_congruences; + int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const; + int compare_root(name const & R, expr e1, expr e2) const; + unsigned symm_hash(name const & R, expr const & lhs, expr const & rhs) const; + congr_key mk_congr_key(ext_congr_lemma const & lemma, expr const & e) const; + void check_iff_true(congr_key const & k); + void mk_entry_for(name const & R, expr const & e); void add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child); void add_congruence_table(ext_congr_lemma const & lemma, expr const & e); void invert_trans(name const & R, expr const & e, optional new_target, optional new_proof); void invert_trans(name const & R, expr const & e); void remove_parents(name const & R, expr const & e); - void insert_parents(name const & R, expr const & e); + void reinsert_parents(name const & R, expr const & e); void add_eqv_step(name const & R, expr e1, expr e2, expr const & H); void add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & H); void display_eqc(name const & R, expr const & e) const; public: + void initialize(); + /** \brief Register expression \c e in this data-structure. It creates entries for each sub-expression in \c e. It also updates the m_parents mapping. @@ -160,6 +173,8 @@ public: /** \brief dump for debugging purposes. */ void display() const; + void display_eqcs() const; + void display_parents() const; bool check_eqc(name const & R, expr const & e) const; bool check_invariant() const; }; @@ -171,4 +186,7 @@ public: scope_congruence_closure(); ~scope_congruence_closure(); }; + +void initialize_congruence_closure(); +void finalize_congruence_closure(); }} diff --git a/src/library/blast/init_module.cpp b/src/library/blast/init_module.cpp index f62e0abfc..60127b436 100644 --- a/src/library/blast/init_module.cpp +++ b/src/library/blast/init_module.cpp @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "library/blast/blast_tactic.h" #include "library/blast/simplifier.h" #include "library/blast/options.h" +#include "library/blast/congruence_closure.h" #include "library/blast/recursor_action.h" #include "library/blast/assert_cc_action.h" #include "library/blast/backward/init_module.h" @@ -27,8 +28,10 @@ void initialize_blast_module() { initialize_blast_tactic(); blast::initialize_recursor_action(); blast::initialize_assert_cc_action(); + blast::initialize_congruence_closure(); } void finalize_blast_module() { + blast::finalize_congruence_closure(); blast::finalize_assert_cc_action(); blast::finalize_recursor_action(); finalize_blast_tactic(); diff --git a/src/library/blast/simple_strategy.cpp b/src/library/blast/simple_strategy.cpp index a5234bc44..768b3a1cf 100644 --- a/src/library/blast/simple_strategy.cpp +++ b/src/library/blast/simple_strategy.cpp @@ -34,12 +34,12 @@ class simple_strategy : public strategy { trace_action("activate"); Try(assumption_contradiction_actions(*hidx)); + TrySolve(assert_cc_action(*hidx)); Try(subst_action(*hidx)); Try(no_confusion_action(*hidx)); Try(discard_action(*hidx)); Try(forward_action(*hidx)); Try(recursor_preprocess_action(*hidx)); - TrySolve(assert_cc_action(*hidx)); return action_result::new_branch(); }