From 28970ef717260a6b13dd1d87f307d7bc10635ce6 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 20 Nov 2015 12:18:52 -0800 Subject: [PATCH] feat(library/blast/congruence_closure): proof extraction --- src/library/blast/assert_cc_action.cpp | 6 +- src/library/blast/congruence_closure.cpp | 213 +++++++++++++++++++++-- src/library/blast/congruence_closure.h | 7 + tests/lean/run/blast_cc1.lean | 19 ++ 4 files changed, 223 insertions(+), 22 deletions(-) create mode 100644 tests/lean/run/blast_cc1.lean diff --git a/src/library/blast/assert_cc_action.cpp b/src/library/blast/assert_cc_action.cpp index 41ae4f294..4e91a98cb 100644 --- a/src/library/blast/assert_cc_action.cpp +++ b/src/library/blast/assert_cc_action.cpp @@ -33,10 +33,10 @@ static congruence_closure & get_cc() { action_result assert_cc_action(hypothesis_idx hidx) { congruence_closure & cc = get_cc(); + // TODO(Leo): consider a target_changed event for branch_extension. + cc.internalize(curr_state().get_target()); cc.add(hidx); - // TODO(Leo): remove the following line - return action_result::new_branch(); - cc.display(); + // cc.display(); if (cc.is_inconsistent()) { try { app_builder & b = get_app_builder(); diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index 2d7eed162..b1023d86a 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -96,6 +96,7 @@ void congruence_closure::mk_entry_for(name const & R, expr const & e) { n.m_root = e; n.m_cg_root = e; n.m_size = 1; + n.m_flipped = false; m_entries.insert(eqc_key(R, e), n); } @@ -373,7 +374,7 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp check_iff_true(k); } -void congruence_closure::internalize(name const & R, expr const & e) { +void congruence_closure::internalize_core(name const & R, expr const & e) { lean_assert(closed(e)); if (has_expr_metavar(e)) return; @@ -390,13 +391,13 @@ void congruence_closure::internalize(name const & R, expr const & e) { return; case expr_kind::Macro: for (unsigned i = 0; i < macro_num_args(e); i++) - internalize(R, macro_arg(e, i)); + internalize_core(R, macro_arg(e, i)); mk_entry_for(R, e); break; case expr_kind::Pi: if (is_arrow(e) && is_prop(binding_domain(e)) && is_prop(binding_body(e))) { - internalize(R, binding_domain(e)); - internalize(R, binding_body(e)); + internalize_core(R, binding_domain(e)); + internalize_core(R, binding_body(e)); } if (is_prop(e)) { mk_entry_for(R, e); @@ -411,13 +412,13 @@ void congruence_closure::internalize(name const & R, expr const & e) { for (expr const & arg : args) { lean_assert(*it); if (auto R1 = head(*it)) { - internalize(*R1, arg); + internalize_core(*R1, arg); add_occurrence(R, e, *R1, arg); } it = &tail(*it); } if (!lemma->m_fixed_fun) { - internalize(get_eq_name(), fn); + internalize_core(get_eq_name(), fn); add_occurrence(get_eq_name(), e, get_eq_name(), fn); } add_congruence_table(*lemma, e); @@ -426,6 +427,12 @@ void congruence_closure::internalize(name const & R, expr const & e) { }} } +void congruence_closure::internalize(name const & R, expr const & e) { + flet set_cc(g_cc, this); + internalize_core(R, e); + process_todo(); +} + void congruence_closure::internalize(expr const & e) { if (is_prop(e)) internalize(get_iff_name(), e); @@ -453,8 +460,9 @@ void congruence_closure::invert_trans(name const & R, expr const & e, optionalm_target) invert_trans(R, *new_n.m_target, some_expr(e), new_n.m_proof); - new_n.m_target = new_target; - new_n.m_proof = new_proof; + new_n.m_target = new_target; + new_n.m_proof = new_proof; + new_n.m_flipped = !new_n.m_flipped; m_entries.insert(k, new_n); } void congruence_closure::invert_trans(name const & R, expr const & e) { @@ -496,6 +504,7 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con auto r1 = m_entries.find(eqc_key(R, n1->m_root)); auto r2 = m_entries.find(eqc_key(R, n2->m_root)); lean_assert(r1 && r2); + bool flipped = false; // We want r2 to be the root of the combined class. @@ -504,6 +513,7 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con std::swap(n1, n2); std::swap(r1, r2); // Remark: we don't apply symmetry eagerly. So, we don't adjust H. + flipped = true; } expr e1_root = n1->m_root; @@ -516,8 +526,9 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con // We want // r1 -> ... -> e1 -> e2 -> ... -> r2 invert_trans(R, e1); - new_n1.m_target = e2; - new_n1.m_proof = H; + new_n1.m_target = e2; + new_n1.m_proof = H; + new_n1.m_flipped = flipped; m_entries.insert(eqc_key(R, e1), new_n1); // The hash code for the parents is going to change @@ -565,9 +576,8 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con m_parents.insert(k2, ps2); } -void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) { +void congruence_closure::process_todo() { auto & todo = get_todo(); - todo.emplace_back(_R, _lhs, _rhs, _H); while (!todo.empty()) { name R; expr lhs, rhs, H; std::tie(R, lhs, rhs, H) = todo.back(); @@ -576,6 +586,11 @@ void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const } } +void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) { + get_todo().emplace_back(_R, _lhs, _rhs, _H); + process_todo(); +} + void congruence_closure::add(hypothesis_idx hidx) { if (is_inconsistent()) return; @@ -593,15 +608,15 @@ void congruence_closure::add(hypothesis_idx hidx) { name R; expr lhs, rhs; if (is_relation_app(p, R, lhs, rhs)) { if (is_neg) { - internalize(get_iff_name(), p); + internalize_core(get_iff_name(), p); add_eqv(get_iff_name(), p, mk_false(), b.mk_iff_false_intro(h.get_self())); } else { - internalize(R, lhs); - internalize(R, rhs); + internalize_core(R, lhs); + internalize_core(R, rhs); add_eqv(R, lhs, rhs, h.get_self()); } } else if (is_prop(p)) { - internalize(get_iff_name(), p); + internalize_core(get_iff_name(), p); if (is_neg) { add_eqv(get_iff_name(), p, mk_false(), b.mk_iff_false_intro(h.get_self())); } else { @@ -619,10 +634,170 @@ bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2 return n1->m_root == n2->m_root; } +expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const { + if (H == *g_congr_mark) { + app_builder & b = get_app_builder(); + buffer lhs_args, rhs_args; + expr const & lhs_fn = get_app_args(lhs, lhs_args); + expr const & rhs_fn = get_app_args(rhs, rhs_args); + lean_assert(lhs_args.size() == rhs_args.size()); + auto lemma = mk_ext_congr_lemma(R, lhs_fn, lhs_args.size()); + lean_assert(lemma); + if (lemma->m_fixed_fun) { + list> const * it1 = &lemma->m_rel_names; + list const * it2 = &lemma->m_congr_lemma.get_arg_kinds(); + buffer lemma_args; + for (unsigned i = 0; i < lhs_args.size(); i++) { + lean_assert(*it1 && *it2); + switch (head(*it2)) { + case congr_arg_kind::Eq: + lean_assert(head(*it1)); + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i])); + break; + case congr_arg_kind::Fixed: + lemma_args.push_back(lhs_args[i]); + break; + case congr_arg_kind::Cast: + lemma_args.push_back(lhs_args[i]); + lemma_args.push_back(rhs_args[i]); + break; + } + it1 = &(tail(*it1)); + it2 = &(tail(*it2)); + } + expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args); + if (lemma->m_lift_needed) + r = b.lift_from_eq(R, r); + return r; + } else { + optional r; + unsigned i = 0; + if (!is_def_eq(lhs_fn, rhs_fn)) { + r = get_eqv_proof(get_eq_name(), lhs_fn, rhs_fn); + } else { + for (; i < lhs_args.size(); i++) { + if (!is_def_eq(lhs_args[i], rhs_args[i])) { + expr g = mk_app(lhs_fn, i, lhs_args.data()); + expr Hi = *get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]); + r = b.mk_congr_arg(g, Hi); + i++; + break; + } + } + if (!r) { + // lhs and rhs are definitionally equal + r = b.mk_eq_refl(lhs); + } + } + lean_assert(r); + for (; i < lhs_args.size(); i++) { + if (is_def_eq(lhs_args[i], rhs_args[i])) { + r = b.mk_congr_fun(*r, lhs_args[i]); + } else { + expr Hi = *get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]); + r = b.mk_congr(*r, Hi); + } + } + if (lemma->m_lift_needed) + r = b.lift_from_eq(R, *r); + return *r; + } + } else if (H == *g_iff_true_mark) { + // TODO(Leo): + lean_unreachable(); + } else { + return H; + } +} + +static expr flip_proof(name const & R, expr const & H, bool flipped) { + if (!flipped) { + return H; + } else if (H == *g_congr_mark) { + return H; + } else if (H == *g_iff_true_mark) { + return H; + } else { + return get_app_builder().mk_symm(R, H); + } +} + +static expr mk_trans(name const & R, optional const & H1, expr const & H2) { + return !H1 ? H2 : get_app_builder().mk_trans(R, *H1, H2); +} + optional congruence_closure::get_eqv_proof(name const & R, expr const & e1, expr const & e2) const { - // TODO(Leo): - std::cout << R << e1 << e2 << "\n"; - return none_expr(); + app_builder & b = get_app_builder(); + if (has_expr_metavar(e1) || has_expr_metavar(e2)) return none_expr(); + if (is_def_eq(e1, e2)) + return some_expr(b.lift_from_eq(R, b.mk_eq_refl(e1))); + auto n1 = m_entries.find(eqc_key(R, e1)); + if (!n1) return none_expr(); + auto n2 = m_entries.find(eqc_key(R, e2)); + if (!n2) return none_expr(); + if (n1->m_root != n2->m_root) return none_expr(); + // 1. Retrieve "path" from e1 to root + buffer path1, Hs1; + rb_tree visited; + expr it1 = e1; + while (true) { + visited.insert(it1); + auto it1_n = m_entries.find(eqc_key(R, it1)); + lean_assert(it1_n); + if (!it1_n->m_target) + break; + path1.push_back(*it1_n->m_target); + Hs1.push_back(flip_proof(R, *it1_n->m_proof, it1_n->m_flipped)); + it1 = *it1_n->m_target; + } + lean_assert(it1 == n1->m_root); + // 2. The path from e2 to root must have at least one element c in visited + // Retrieve "path" from e2 to c + buffer path2, Hs2; + path2.push_back(e2); + expr it2 = e2; + while (true) { + if (visited.contains(it2)) + break; // found common + auto it2_n = m_entries.find(eqc_key(R, it2)); + lean_assert(it2_n); + lean_assert(it2_n->m_target); + path2.push_back(it2); + Hs2.push_back(flip_proof(R, *it2_n->m_proof, !it2_n->m_flipped)); + it2 = *it2_n->m_target; + } + // it2 is the common element... + // 3. Shink path1/Hs1 until we find it2 (the common element) + while (true) { + if (path1.empty()) { + lean_assert(it2 == e1 && e1 == n2->m_root); + break; + } + if (path1.back() == it2) { + // found it! + break; + } + path1.pop_back(); + Hs1.pop_back(); + } + + // 4. Build transitivity proof + optional pr; + expr lhs = e1; + for (unsigned i = 0; i < path1.size(); i++) { + pr = mk_trans(R, pr, mk_proof(R, lhs, path1[i], Hs1[i])); + lhs = path1[i]; + } + unsigned i = Hs2.size(); + while (i > 0) { + --i; + pr = mk_trans(R, pr, mk_proof(R, lhs, path2[i], Hs2[i])); + lhs = path2[i]; + } + lean_assert(pr); + return pr; } bool congruence_closure::is_uneqv(name const & R, expr const & e1, expr const & e2) const { diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index 736840a5d..ec8c3229b 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -43,7 +43,9 @@ class congruence_closure { // store 'target' at 'm_target', and 'H' at 'm_proof'. Both fields are none if 'e' == m_root optional m_target; optional m_proof; + bool m_flipped; // proof has been flipped unsigned m_size; // number of elements in the equivalence class, it is meaningless if 'e' != m_root + }; /* Key (R, e) for the mapping (R, e) -> entry */ @@ -97,6 +99,9 @@ class congruence_closure { parents m_parents; congruences m_congruences; + void internalize_core(name const & R, expr const & e); + void process_todo(); + int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const; int compare_root(name const & R, expr e1, expr e2) const; unsigned symm_hash(name const & R, expr const & lhs, expr const & rhs) const; @@ -113,6 +118,8 @@ class congruence_closure { void add_eqv_step(name const & R, expr e1, expr e2, expr const & H); void add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & H); + expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const; + void display_eqc(name const & R, expr const & e) const; public: void initialize(); diff --git a/tests/lean/run/blast_cc1.lean b/tests/lean/run/blast_cc1.lean new file mode 100644 index 000000000..1f64b89d3 --- /dev/null +++ b/tests/lean/run/blast_cc1.lean @@ -0,0 +1,19 @@ +import data.list + +constant f {A : Type} : A → A → A +constant g : nat → nat +set_option blast.subst false + +example (a b c : nat) : a = b → g a == g b := +by blast + +example (a b c : nat) : a = b → c = b → f (f a b) (g c) = f (f c a) (g b) := +by blast + +example (a b c d e x y : nat) : a = b → a = x → b = y → c = d → c = e → c = b → a = e := +by blast + +open perm + +example (a b c d : list nat) : a ~ b → c ~ b → d ~ c → a ~ d := +by blast