From 968b6153904ba63ec60b6c38ee250eb22e5acc45 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 19 Nov 2015 15:55:56 -0800 Subject: [PATCH] feat(library/blast/congruence_closure): add "add_eqv" --- src/library/blast/congruence_closure.cpp | 171 +++++++++++++++++++++-- src/library/blast/congruence_closure.h | 13 +- 2 files changed, 173 insertions(+), 11 deletions(-) diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index 86d4eff20..9b8b6875e 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "library/blast/congruence_closure.h" #include "library/blast/util.h" #include "library/blast/blast.h" +#include "library/blast/trace.h" namespace lean { namespace blast { @@ -88,7 +89,7 @@ void congruence_closure::mk_entry_for(name const & R, expr const & e) { n.m_next = e; n.m_root = e; n.m_cg_root = e; - n.m_rank = 0; + n.m_size = 1; m_entries.insert(eqc_key(R, e), n); } @@ -159,6 +160,8 @@ void congruence_closure::internalize(name const & R, expr const & e) { lean_assert(closed(e)); if (has_expr_metavar(e)) return; + if (m_entries.find(eqc_key(R, e))) + return; // e has already been internalized switch (e.kind()) { case expr_kind::Var: case expr_kind::Meta: lean_unreachable(); @@ -221,14 +224,117 @@ static void clear_todo() { get_todo().clear(); } -void congruence_closure::add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & pr) { +/* + The fields m_target and m_proof in e's entry are encoding a transitivity proof + Let target(e) and proof(e) denote these fields. + + e = target(e) : proof(e) + ... = target(target(e)) : proof(target(e)) + ... ... + = root(e) : ... + + The transitivity proof eventually reaches the root of the equivalence class. + This method "inverts" the proof. That is, the m_target goes from root(e) to e after + we execute it. +*/ +void congruence_closure::invert_trans(name const & R, expr const & e, optional new_target, optional new_proof) { + eqc_key k(R, e); + auto n = m_entries.find(k); + lean_assert(n); + entry new_n = *n; + if (n->m_target) + invert_trans(R, *new_n.m_target, some_expr(e), new_n.m_proof); + new_n.m_target = new_target; + new_n.m_proof = new_proof; + m_entries.insert(k, new_n); +} +void congruence_closure::invert_trans(name const & R, expr const & e) { + invert_trans(R, e, none_expr(), none_expr()); +} + +void congruence_closure::remove_parents(name const & R, expr const & e) { + std::cout << R << " " << e << "\n"; + // TODO(Leo): +} + +void congruence_closure::insert_parents(name const & R, expr const & e) { + std::cout << R << " " << e << "\n"; + // TODO(Leo): +} + +void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr const & H) { + auto n1 = m_entries.find(eqc_key(R, e1)); + auto n2 = m_entries.find(eqc_key(R, e2)); + if (!n1 || !n2) + return; + if (n1->m_root == n2->m_root) + return; // they are already in the same equivalence class + auto r1 = m_entries.find(eqc_key(R, n1->m_root)); + auto r2 = m_entries.find(eqc_key(R, n2->m_root)); + lean_assert(r1 && r2); + + // We want r2 to be the root of the combined class. + + if (r1->m_size > r2->m_size) { + std::swap(e1, e2); + std::swap(n1, n2); + std::swap(r1, r2); + // Remark: we don't apply symmetry eagerly. So, we don't adjust H. + } + + expr e1_root = n1->m_root; + expr e2_root = n2->m_root; + entry new_n1 = *n1; + + // Following target/proof we have + // e1 -> ... -> r1 + // e2 -> ... -> r2 + // We want + // r1 -> ... -> e1 -> e2 -> ... -> r2 + invert_trans(R, e1); + new_n1.m_target = e2; + new_n1.m_proof = H; + m_entries.insert(eqc_key(R, e1), new_n1); + + // The hash code for the parents is going to change + remove_parents(R, e1); + + // force all m_root fields in e1 equivalence class to point to e2_root + expr it = e1; + do { + auto it_n = m_entries.find(eqc_key(R, it)); + lean_assert(it_n); + entry new_it_n = *it_n; + new_it_n.m_root = e2_root; + m_entries.insert(eqc_key(R, it), new_it_n); + it = new_it_n.m_next; + } while (it != e1); + + insert_parents(R, e1); + + // update next of e1_root and e2_root, and size of e2_root + r1 = m_entries.find(eqc_key(R, e1_root)); + r2 = m_entries.find(eqc_key(R, e2_root)); + lean_assert(r1 && r2); + lean_assert(r1->m_root == e2_root); + entry new_r1 = *r1; + entry new_r2 = *r2; + new_r1.m_next = r2->m_next; + new_r2.m_next = r1->m_next; + new_r2.m_size += r1->m_size; + m_entries.insert(eqc_key(R, e1_root), new_r1); + m_entries.insert(eqc_key(R, e2_root), new_r2); + lean_assert(check_invariant()); +} + +void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) { auto & todo = get_todo(); - todo.emplace_back(R, lhs, rhs, pr); + todo.emplace_back(_R, _lhs, _rhs, _H); while (!todo.empty()) { - name R; expr lhs, rhs, pr; - std::tie(R, lhs, rhs, pr) = todo.back(); + name R; expr lhs, rhs, H; + std::tie(R, lhs, rhs, H) = todo.back(); todo.pop_back(); - // TODO(Leo): process + add_eqv_step(R, lhs, rhs, H); } } @@ -241,7 +347,7 @@ void congruence_closure::add(hypothesis_idx hidx) { hypothesis const & h = s.get_hypothesis_decl(hidx); try { expr const & type = h.get_type(); - expr p; + expr p = type; bool is_neg = is_not(type, p); if (is_neg && !is_standard(env())) return; @@ -385,11 +491,60 @@ expr congruence_closure::get_next(name const & R, expr const & e) const { } } +void congruence_closure::display_eqc(name const & R, expr const & e) const { + auto out = diagnostic(env(), ios()); + bool first = true; + expr it = e; + out << R << " {"; + do { + auto it_n = m_entries.find(eqc_key(R, it)); + if (first) first = false; else out << ", "; + out << ppb(it); + it = it_n->m_next; + } while (it != e); + out << "}"; +} + void congruence_closure::display() const { - // TODO(Leo): + auto out = diagnostic(env(), ios()); + m_entries.for_each([&](eqc_key const & k, entry const & n) { + if (k.m_expr == n.m_root) { + display_eqc(k.m_R, k.m_expr); + out << "\n"; + } + }); +} + +bool congruence_closure::check_eqc(name const & R, expr const & e) const { + expr root = get_root(R, e); + unsigned size = 0; + expr it = e; + do { + auto it_n = m_entries.find(eqc_key(R, it)); + lean_assert(it_n); + lean_assert(it_n->m_root == root); + auto it2 = it; + // following m_target fields should lead to root + while (true) { + auto it2_n = m_entries.find(eqc_key(R, it2)); + if (!it2_n->m_target) + break; + it2 = *it2_n->m_target; + } + lean_assert(it2 == root); + it = it_n->m_next; + size++; + } while (it != e); + lean_assert(m_entries.find(eqc_key(R, root))->m_size == size); + return true; } bool congruence_closure::check_invariant() const { + m_entries.for_each([&](eqc_key const & k, entry const & n) { + if (k.m_expr == n.m_root) { + lean_assert(check_eqc(k.m_R, k.m_expr)); + } + }); // TODO(Leo): return true; } diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index 86250c627..eee7c2eb6 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -43,7 +43,7 @@ class congruence_closure { // store 'target' at 'm_target', and 'H' at 'm_proof'. Both fields are none if 'e' == m_root optional m_target; optional m_proof; - unsigned m_rank; // rank of the equivalence class, it is meaningless if 'e' != m_root + unsigned m_size; // number of elements in the equivalence class, it is meaningless if 'e' != m_root }; /* Key (R, e) for the mapping (R, e) -> entry */ @@ -58,7 +58,7 @@ class congruence_closure { int operator()(eqc_key const & k1, eqc_key const & k2) const { int r = quick_cmp(k1.m_R, k2.m_R); if (r != 0) return r; - else return is_lt(k1.m_expr, k2.m_expr, true); + else return expr_quick_cmp()(k1.m_expr, k2.m_expr); } }; @@ -95,8 +95,14 @@ class congruence_closure { void mk_entry_for(name const & R, expr const & e); void add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child); void add_congruence_table(ext_congr_lemma const & lemma, expr const & e); - void add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & pr); + void invert_trans(name const & R, expr const & e, optional new_target, optional new_proof); + void invert_trans(name const & R, expr const & e); + void remove_parents(name const & R, expr const & e); + void insert_parents(name const & R, expr const & e); + void add_eqv_step(name const & R, expr e1, expr e2, expr const & H); + void add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & H); + void display_eqc(name const & R, expr const & e) const; public: /** \brief Register expression \c e in this data-structure. It creates entries for each sub-expression in \c e. @@ -154,6 +160,7 @@ public: /** \brief dump for debugging purposes. */ void display() const; + bool check_eqc(name const & R, expr const & e) const; bool check_invariant() const; };