From 8f368cebbfe91f78a7d3a4f5ce3fff225503ac9d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 20 Nov 2015 13:46:29 -0800 Subject: [PATCH] feat(library/blast/congruence_closure): lift equalities --- src/library/blast/congruence_closure.cpp | 137 +++++++++++++++-------- src/library/blast/congruence_closure.h | 10 +- tests/lean/run/blast_cc1.lean | 3 + tests/lean/run/blast_cc2.lean | 24 ++++ 4 files changed, 127 insertions(+), 47 deletions(-) create mode 100644 tests/lean/run/blast_cc2.lean diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index b1023d86a..c2e3b1bd4 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -71,9 +71,28 @@ struct ext_congr_lemma { /* We use the following cache for user-defined lemmas and automatically generated ones. */ typedef std::unordered_map, congr_lemma_key_hash_fn, congr_lemma_key_eq_fn> congr_cache; +typedef std::tuple cc_todo_entry; + +static expr * g_congr_mark = nullptr; // dummy congruence proof, it is just a placeholder. +static expr * g_iff_true_mark = nullptr; // dummy iff_true proof, it is just a placeholder. +static expr * g_lift_mark = nullptr; // dummy lift eq proof, it is just a placeholder. + +/* Small hack for not storing a pointer to the congruence_closure object + at congruence_closure::congr_key_cmp */ +LEAN_THREAD_PTR(congruence_closure, g_cc); LEAN_THREAD_VALUE(congr_cache *, g_congr_cache, nullptr); +MK_THREAD_LOCAL_GET_DEF(std::vector, get_todo); + +static void clear_todo() { + get_todo().clear(); +} + +static void push_todo(name const & R, expr const & lhs, expr const & rhs, expr const & H) { + get_todo().emplace_back(R, lhs, rhs, H); +} + scope_congruence_closure::scope_congruence_closure(): m_old_cache(g_congr_cache) { g_congr_cache = new congr_cache(); @@ -85,11 +104,11 @@ scope_congruence_closure::~scope_congruence_closure() { } void congruence_closure::initialize() { - mk_entry_for(get_iff_name(), mk_true()); - mk_entry_for(get_iff_name(), mk_false()); + mk_entry_core(get_iff_name(), mk_true()); + mk_entry_core(get_iff_name(), mk_false()); } -void congruence_closure::mk_entry_for(name const & R, expr const & e) { +void congruence_closure::mk_entry_core(name const & R, expr const & e) { lean_assert(!m_entries.find(eqc_key(R, e))); entry n; n.m_next = e; @@ -98,6 +117,29 @@ void congruence_closure::mk_entry_for(name const & R, expr const & e) { n.m_size = 1; n.m_flipped = false; m_entries.insert(eqc_key(R, e), n); + if (R != get_eq_name()) { + // lift equalities to R + auto n = m_entries.find(eqc_key(get_eq_name(), e)); + if (n) { + // e has an eq equivalence class + expr it = n->m_next; + while (it != e) { + if (m_entries.find(eqc_key(R, it))) { + push_todo(R, e, it, *g_lift_mark); + break; + } + auto it_n = m_entries.find(eqc_key(get_eq_name(), it)); + lean_assert(it_n); + it = it_n->m_next; + } + } + } +} + +void congruence_closure::mk_entry(name const & R, expr const & e) { + if (m_entries.find(eqc_key(R, e))) + return; + mk_entry_core(R, e); } static optional mk_ext_congr_lemma_core(name const & R, expr const & fn, unsigned nargs) { @@ -153,6 +195,13 @@ static optional mk_ext_congr_lemma(name const & R, expr const & return r; } +void congruence_closure::update_non_eq_relations(name const & R) { + if (R == get_eq_name()) + return; + if (std::find(m_non_eq_relations.begin(), m_non_eq_relations.end(), R) == m_non_eq_relations.end()) + m_non_eq_relations = cons(R, m_non_eq_relations); +} + void congruence_closure::add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child) { child_key k(Rc, child); parent_occ_set ps; @@ -162,10 +211,6 @@ void congruence_closure::add_occurrence(name const & Rp, expr const & parent, na m_parents.insert(k, ps); } -/* Small hack for not storing a pointer to the congruence_closure object - at congruence_closure::congr_key_cmp */ -LEAN_THREAD_PTR(congruence_closure, g_cc); - /* Auxiliary function for comparing (lhs1 ~ rhs1) and (lhs2 ~ rhs2), when ~ is symmetric. It returns 0 (equal) for (a ~ b) (b ~ a) */ @@ -323,17 +368,6 @@ auto congruence_closure::mk_congr_key(ext_congr_lemma const & lemma, expr const return k; } -static expr * g_congr_mark = nullptr; // dummy congruence proof, it is just a placeholder. -static expr * g_iff_true_mark = nullptr; // dummy iff_true proof, it is just a placeholder. - -typedef std::tuple cc_todo_entry; - -MK_THREAD_LOCAL_GET_DEF(std::vector, get_todo); - -static void clear_todo() { - get_todo().clear(); -} - void congruence_closure::check_iff_true(congr_key const & k) { expr const & e = k.m_expr; name R; expr lhs, rhs; @@ -353,7 +387,7 @@ void congruence_closure::check_iff_true(congr_key const & k) { if (lhs != rhs) return; // Add e <-> true - get_todo().emplace_back(get_iff_name(), e, mk_true(), *g_iff_true_mark); + push_todo(get_iff_name(), e, mk_true(), *g_iff_true_mark); } void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, expr const & e) { @@ -367,7 +401,7 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp new_entry.m_cg_root = old_k->m_expr; m_entries.insert(k, new_entry); // 2. Put new equivalence in the TODO queue - get_todo().emplace_back(lemma.m_R, e, old_k->m_expr, *g_congr_mark); + push_todo(lemma.m_R, e, old_k->m_expr, *g_congr_mark); } else { m_congruences.insert(k); } @@ -380,6 +414,7 @@ void congruence_closure::internalize_core(name const & R, expr const & e) { return; if (m_entries.find(eqc_key(R, e))) return; // e has already been internalized + update_non_eq_relations(R); switch (e.kind()) { case expr_kind::Var: case expr_kind::Meta: lean_unreachable(); @@ -387,12 +422,12 @@ void congruence_closure::internalize_core(name const & R, expr const & e) { return; case expr_kind::Constant: case expr_kind::Local: case expr_kind::Lambda: - mk_entry_for(R, e); + mk_entry_core(R, e); return; case expr_kind::Macro: for (unsigned i = 0; i < macro_num_args(e); i++) internalize_core(R, macro_arg(e, i)); - mk_entry_for(R, e); + mk_entry_core(R, e); break; case expr_kind::Pi: if (is_arrow(e) && is_prop(binding_domain(e)) && is_prop(binding_body(e))) { @@ -400,11 +435,11 @@ void congruence_closure::internalize_core(name const & R, expr const & e) { internalize_core(R, binding_body(e)); } if (is_prop(e)) { - mk_entry_for(R, e); + mk_entry_core(R, e); } return; case expr_kind::App: { - mk_entry_for(R, e); + mk_entry_core(R, e); buffer args; expr const & fn = get_app_args(e, args); if (auto lemma = mk_ext_congr_lemma(R, fn, args.size())) { @@ -485,6 +520,8 @@ void congruence_closure::remove_parents(name const & R, expr const & e) { void congruence_closure::reinsert_parents(name const & R, expr const & e) { auto ps = m_parents.find(child_key(R, e)); if (!ps) return; + // TODO(Leo): consider the following optimization: + // 1- remove from ps any parent that is not a congruence root anymore ps->for_each([&](parent_occ const & p) { expr const & fn = get_app_fn(p.m_expr); unsigned nargs = get_app_num_args(p.m_expr); @@ -564,16 +601,29 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con // copy e1_root parents to e2_root child_key k1(R, e1_root); auto ps1 = m_parents.find(k1); - if (!ps1) return; // e1_root doesn't have parents - parent_occ_set ps2; - child_key k2(R, e2_root); - if (auto it = m_parents.find(k2)) - ps2 = *it; - ps1->for_each([&](parent_occ const & p) { - ps2.insert(p); - }); - m_parents.erase(k1); - m_parents.insert(k2, ps2); + if (ps1) { + parent_occ_set ps2; + child_key k2(R, e2_root); + if (auto it = m_parents.find(k2)) + ps2 = *it; + ps1->for_each([&](parent_occ const & p) { + ps2.insert(p); + }); + m_parents.erase(k1); + m_parents.insert(k2, ps2); + } + + // lift equivalence + if (R == get_eq_name()) { + for (name const & R2 : m_non_eq_relations) { + if (m_entries.find(eqc_key(R2, e1)) || + m_entries.find(eqc_key(R2, e2))) { + mk_entry(R2, e1); + mk_entry(R2, e2); + push_todo(R2, e1, e2, *g_lift_mark); + } + } + } } void congruence_closure::process_todo() { @@ -586,8 +636,8 @@ void congruence_closure::process_todo() { } } -void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) { - get_todo().emplace_back(_R, _lhs, _rhs, _H); +void congruence_closure::add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & H) { + push_todo(R, lhs, rhs, H); process_todo(); } @@ -707,17 +757,16 @@ expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const & } else if (H == *g_iff_true_mark) { // TODO(Leo): lean_unreachable(); + } else if (H == *g_lift_mark) { + expr H1 = *get_eqv_proof(get_eq_name(), lhs, rhs); + return get_app_builder().lift_from_eq(R, H1); } else { return H; } } static expr flip_proof(name const & R, expr const & H, bool flipped) { - if (!flipped) { - return H; - } else if (H == *g_congr_mark) { - return H; - } else if (H == *g_iff_true_mark) { + if (!flipped || H == *g_congr_mark || H == *g_iff_true_mark || H == *g_lift_mark) { return H; } else { return get_app_builder().mk_symm(R, H); @@ -756,7 +805,6 @@ optional congruence_closure::get_eqv_proof(name const & R, expr const & e1 // 2. The path from e2 to root must have at least one element c in visited // Retrieve "path" from e2 to c buffer path2, Hs2; - path2.push_back(e2); expr it2 = e2; while (true) { if (visited.contains(it2)) @@ -772,7 +820,7 @@ optional congruence_closure::get_eqv_proof(name const & R, expr const & e1 // 3. Shink path1/Hs1 until we find it2 (the common element) while (true) { if (path1.empty()) { - lean_assert(it2 == e1 && e1 == n2->m_root); + lean_assert(it2 == e1); break; } if (path1.back() == it2) { @@ -986,7 +1034,6 @@ bool congruence_closure::check_invariant() const { lean_assert(check_eqc(k.m_R, k.m_expr)); } }); - // TODO(Leo): return true; } @@ -994,10 +1041,12 @@ void initialize_congruence_closure() { name prefix = name::mk_internal_unique_name(); g_congr_mark = new expr(mk_constant(name(prefix, "[congruence]"))); g_iff_true_mark = new expr(mk_constant(name(prefix, "[iff-true]"))); + g_lift_mark = new expr(mk_constant(name(prefix, "[lift]"))); } void finalize_congruence_closure() { delete g_congr_mark; delete g_iff_true_mark; + delete g_lift_mark; } }} diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index ec8c3229b..2ab467c7c 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -98,6 +98,9 @@ class congruence_closure { entries m_entries; parents m_parents; congruences m_congruences; + list m_non_eq_relations; + + void update_non_eq_relations(name const & R); void internalize_core(name const & R, expr const & e); void process_todo(); @@ -108,7 +111,8 @@ class congruence_closure { congr_key mk_congr_key(ext_congr_lemma const & lemma, expr const & e) const; void check_iff_true(congr_key const & k); - void mk_entry_for(name const & R, expr const & e); + void mk_entry_core(name const & R, expr const & e); + void mk_entry(name const & R, expr const & e); void add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child); void add_congruence_table(ext_congr_lemma const & lemma, expr const & e); void invert_trans(name const & R, expr const & e, optional new_target, optional new_proof); @@ -157,9 +161,9 @@ public: /** \brief Return the proof of inconsistency */ optional get_inconsistency_proof() const; - /** \brief Return true iff 'e1' and 'e2' are in the same equivalence class for relation \c rel_name. */ + /** \brief Return true iff 'e1' and 'e2' are in the same equivalence class for relation \c R. */ bool is_eqv(name const & R, expr const & e1, expr const & e2) const; - optional get_eqv_proof(name const & rel_name, expr const & e1, expr const & e2) const; + optional get_eqv_proof(name const & R, expr const & e1, expr const & e2) const; /** \brief Return true iff `e1 ~ e2` is in the equivalence class of false for iff. */ bool is_uneqv(name const & R, expr const & e1, expr const & e2) const; diff --git a/tests/lean/run/blast_cc1.lean b/tests/lean/run/blast_cc1.lean index 1f64b89d3..60f14b5b8 100644 --- a/tests/lean/run/blast_cc1.lean +++ b/tests/lean/run/blast_cc1.lean @@ -17,3 +17,6 @@ open perm example (a b c d : list nat) : a ~ b → c ~ b → d ~ c → a ~ d := by blast + +example (a b c d : list nat) : a ~ b → c ~ b → d = c → a ~ d := +by blast diff --git a/tests/lean/run/blast_cc2.lean b/tests/lean/run/blast_cc2.lean new file mode 100644 index 000000000..a97880c60 --- /dev/null +++ b/tests/lean/run/blast_cc2.lean @@ -0,0 +1,24 @@ +set_option blast.init_depth 10 +set_option blast.inc_depth 100 +set_option blast.subst false + +example (a b c d : nat) : a == b → b = c → c == d → a == d := +by blast + +example (a b c d : nat) : a = b → b = c → c == d → a == d := +by blast + +example (a b c d : nat) : a = b → b == c → c == d → a == d := +by blast + +example (a b c d : nat) : a == b → b == c → c = d → a == d := +by blast + +example (a b c d : nat) : a == b → b = c → c = d → a == d := +by blast + +example (a b c d : nat) : a = b → b = c → c = d → a == d := +by blast + +example (a b c d : nat) : a = b → b == c → c = d → a == d := +by blast