feat(library/blast/congruence_closure): proof extraction

This commit is contained in:
Leonardo de Moura 2015-11-20 12:18:52 -08:00
parent 1e3f549c39
commit 28970ef717
4 changed files with 223 additions and 22 deletions

View file

@ -33,10 +33,10 @@ static congruence_closure & get_cc() {
action_result assert_cc_action(hypothesis_idx hidx) { action_result assert_cc_action(hypothesis_idx hidx) {
congruence_closure & cc = get_cc(); congruence_closure & cc = get_cc();
// TODO(Leo): consider a target_changed event for branch_extension.
cc.internalize(curr_state().get_target());
cc.add(hidx); cc.add(hidx);
// TODO(Leo): remove the following line // cc.display();
return action_result::new_branch();
cc.display();
if (cc.is_inconsistent()) { if (cc.is_inconsistent()) {
try { try {
app_builder & b = get_app_builder(); app_builder & b = get_app_builder();

View file

@ -96,6 +96,7 @@ void congruence_closure::mk_entry_for(name const & R, expr const & e) {
n.m_root = e; n.m_root = e;
n.m_cg_root = e; n.m_cg_root = e;
n.m_size = 1; n.m_size = 1;
n.m_flipped = false;
m_entries.insert(eqc_key(R, e), n); m_entries.insert(eqc_key(R, e), n);
} }
@ -373,7 +374,7 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp
check_iff_true(k); check_iff_true(k);
} }
void congruence_closure::internalize(name const & R, expr const & e) { void congruence_closure::internalize_core(name const & R, expr const & e) {
lean_assert(closed(e)); lean_assert(closed(e));
if (has_expr_metavar(e)) if (has_expr_metavar(e))
return; return;
@ -390,13 +391,13 @@ void congruence_closure::internalize(name const & R, expr const & e) {
return; return;
case expr_kind::Macro: case expr_kind::Macro:
for (unsigned i = 0; i < macro_num_args(e); i++) for (unsigned i = 0; i < macro_num_args(e); i++)
internalize(R, macro_arg(e, i)); internalize_core(R, macro_arg(e, i));
mk_entry_for(R, e); mk_entry_for(R, e);
break; break;
case expr_kind::Pi: case expr_kind::Pi:
if (is_arrow(e) && is_prop(binding_domain(e)) && is_prop(binding_body(e))) { if (is_arrow(e) && is_prop(binding_domain(e)) && is_prop(binding_body(e))) {
internalize(R, binding_domain(e)); internalize_core(R, binding_domain(e));
internalize(R, binding_body(e)); internalize_core(R, binding_body(e));
} }
if (is_prop(e)) { if (is_prop(e)) {
mk_entry_for(R, e); mk_entry_for(R, e);
@ -411,13 +412,13 @@ void congruence_closure::internalize(name const & R, expr const & e) {
for (expr const & arg : args) { for (expr const & arg : args) {
lean_assert(*it); lean_assert(*it);
if (auto R1 = head(*it)) { if (auto R1 = head(*it)) {
internalize(*R1, arg); internalize_core(*R1, arg);
add_occurrence(R, e, *R1, arg); add_occurrence(R, e, *R1, arg);
} }
it = &tail(*it); it = &tail(*it);
} }
if (!lemma->m_fixed_fun) { if (!lemma->m_fixed_fun) {
internalize(get_eq_name(), fn); internalize_core(get_eq_name(), fn);
add_occurrence(get_eq_name(), e, get_eq_name(), fn); add_occurrence(get_eq_name(), e, get_eq_name(), fn);
} }
add_congruence_table(*lemma, e); add_congruence_table(*lemma, e);
@ -426,6 +427,12 @@ void congruence_closure::internalize(name const & R, expr const & e) {
}} }}
} }
void congruence_closure::internalize(name const & R, expr const & e) {
flet<congruence_closure *> set_cc(g_cc, this);
internalize_core(R, e);
process_todo();
}
void congruence_closure::internalize(expr const & e) { void congruence_closure::internalize(expr const & e) {
if (is_prop(e)) if (is_prop(e))
internalize(get_iff_name(), e); internalize(get_iff_name(), e);
@ -455,6 +462,7 @@ void congruence_closure::invert_trans(name const & R, expr const & e, optional<e
invert_trans(R, *new_n.m_target, some_expr(e), new_n.m_proof); invert_trans(R, *new_n.m_target, some_expr(e), new_n.m_proof);
new_n.m_target = new_target; new_n.m_target = new_target;
new_n.m_proof = new_proof; new_n.m_proof = new_proof;
new_n.m_flipped = !new_n.m_flipped;
m_entries.insert(k, new_n); m_entries.insert(k, new_n);
} }
void congruence_closure::invert_trans(name const & R, expr const & e) { void congruence_closure::invert_trans(name const & R, expr const & e) {
@ -496,6 +504,7 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
auto r1 = m_entries.find(eqc_key(R, n1->m_root)); auto r1 = m_entries.find(eqc_key(R, n1->m_root));
auto r2 = m_entries.find(eqc_key(R, n2->m_root)); auto r2 = m_entries.find(eqc_key(R, n2->m_root));
lean_assert(r1 && r2); lean_assert(r1 && r2);
bool flipped = false;
// We want r2 to be the root of the combined class. // We want r2 to be the root of the combined class.
@ -504,6 +513,7 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
std::swap(n1, n2); std::swap(n1, n2);
std::swap(r1, r2); std::swap(r1, r2);
// Remark: we don't apply symmetry eagerly. So, we don't adjust H. // Remark: we don't apply symmetry eagerly. So, we don't adjust H.
flipped = true;
} }
expr e1_root = n1->m_root; expr e1_root = n1->m_root;
@ -518,6 +528,7 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
invert_trans(R, e1); invert_trans(R, e1);
new_n1.m_target = e2; new_n1.m_target = e2;
new_n1.m_proof = H; new_n1.m_proof = H;
new_n1.m_flipped = flipped;
m_entries.insert(eqc_key(R, e1), new_n1); m_entries.insert(eqc_key(R, e1), new_n1);
// The hash code for the parents is going to change // The hash code for the parents is going to change
@ -565,9 +576,8 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
m_parents.insert(k2, ps2); m_parents.insert(k2, ps2);
} }
void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) { void congruence_closure::process_todo() {
auto & todo = get_todo(); auto & todo = get_todo();
todo.emplace_back(_R, _lhs, _rhs, _H);
while (!todo.empty()) { while (!todo.empty()) {
name R; expr lhs, rhs, H; name R; expr lhs, rhs, H;
std::tie(R, lhs, rhs, H) = todo.back(); std::tie(R, lhs, rhs, H) = todo.back();
@ -576,6 +586,11 @@ void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const
} }
} }
void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) {
get_todo().emplace_back(_R, _lhs, _rhs, _H);
process_todo();
}
void congruence_closure::add(hypothesis_idx hidx) { void congruence_closure::add(hypothesis_idx hidx) {
if (is_inconsistent()) if (is_inconsistent())
return; return;
@ -593,15 +608,15 @@ void congruence_closure::add(hypothesis_idx hidx) {
name R; expr lhs, rhs; name R; expr lhs, rhs;
if (is_relation_app(p, R, lhs, rhs)) { if (is_relation_app(p, R, lhs, rhs)) {
if (is_neg) { if (is_neg) {
internalize(get_iff_name(), p); internalize_core(get_iff_name(), p);
add_eqv(get_iff_name(), p, mk_false(), b.mk_iff_false_intro(h.get_self())); add_eqv(get_iff_name(), p, mk_false(), b.mk_iff_false_intro(h.get_self()));
} else { } else {
internalize(R, lhs); internalize_core(R, lhs);
internalize(R, rhs); internalize_core(R, rhs);
add_eqv(R, lhs, rhs, h.get_self()); add_eqv(R, lhs, rhs, h.get_self());
} }
} else if (is_prop(p)) { } else if (is_prop(p)) {
internalize(get_iff_name(), p); internalize_core(get_iff_name(), p);
if (is_neg) { if (is_neg) {
add_eqv(get_iff_name(), p, mk_false(), b.mk_iff_false_intro(h.get_self())); add_eqv(get_iff_name(), p, mk_false(), b.mk_iff_false_intro(h.get_self()));
} else { } else {
@ -619,10 +634,170 @@ bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2
return n1->m_root == n2->m_root; return n1->m_root == n2->m_root;
} }
optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1, expr const & e2) const { expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const {
if (H == *g_congr_mark) {
app_builder & b = get_app_builder();
buffer<expr> lhs_args, rhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args);
expr const & rhs_fn = get_app_args(rhs, rhs_args);
lean_assert(lhs_args.size() == rhs_args.size());
auto lemma = mk_ext_congr_lemma(R, lhs_fn, lhs_args.size());
lean_assert(lemma);
if (lemma->m_fixed_fun) {
list<optional<name>> const * it1 = &lemma->m_rel_names;
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
buffer<expr> lemma_args;
for (unsigned i = 0; i < lhs_args.size(); i++) {
lean_assert(*it1 && *it2);
switch (head(*it2)) {
case congr_arg_kind::Eq:
lean_assert(head(*it1));
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Fixed:
lemma_args.push_back(lhs_args[i]);
break;
case congr_arg_kind::Cast:
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
break;
}
it1 = &(tail(*it1));
it2 = &(tail(*it2));
}
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
if (lemma->m_lift_needed)
r = b.lift_from_eq(R, r);
return r;
} else {
optional<expr> r;
unsigned i = 0;
if (!is_def_eq(lhs_fn, rhs_fn)) {
r = get_eqv_proof(get_eq_name(), lhs_fn, rhs_fn);
} else {
for (; i < lhs_args.size(); i++) {
if (!is_def_eq(lhs_args[i], rhs_args[i])) {
expr g = mk_app(lhs_fn, i, lhs_args.data());
expr Hi = *get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]);
r = b.mk_congr_arg(g, Hi);
i++;
break;
}
}
if (!r) {
// lhs and rhs are definitionally equal
r = b.mk_eq_refl(lhs);
}
}
lean_assert(r);
for (; i < lhs_args.size(); i++) {
if (is_def_eq(lhs_args[i], rhs_args[i])) {
r = b.mk_congr_fun(*r, lhs_args[i]);
} else {
expr Hi = *get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]);
r = b.mk_congr(*r, Hi);
}
}
if (lemma->m_lift_needed)
r = b.lift_from_eq(R, *r);
return *r;
}
} else if (H == *g_iff_true_mark) {
// TODO(Leo): // TODO(Leo):
std::cout << R << e1 << e2 << "\n"; lean_unreachable();
return none_expr(); } else {
return H;
}
}
static expr flip_proof(name const & R, expr const & H, bool flipped) {
if (!flipped) {
return H;
} else if (H == *g_congr_mark) {
return H;
} else if (H == *g_iff_true_mark) {
return H;
} else {
return get_app_builder().mk_symm(R, H);
}
}
static expr mk_trans(name const & R, optional<expr> const & H1, expr const & H2) {
return !H1 ? H2 : get_app_builder().mk_trans(R, *H1, H2);
}
optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1, expr const & e2) const {
app_builder & b = get_app_builder();
if (has_expr_metavar(e1) || has_expr_metavar(e2)) return none_expr();
if (is_def_eq(e1, e2))
return some_expr(b.lift_from_eq(R, b.mk_eq_refl(e1)));
auto n1 = m_entries.find(eqc_key(R, e1));
if (!n1) return none_expr();
auto n2 = m_entries.find(eqc_key(R, e2));
if (!n2) return none_expr();
if (n1->m_root != n2->m_root) return none_expr();
// 1. Retrieve "path" from e1 to root
buffer<expr> path1, Hs1;
rb_tree<expr, expr_quick_cmp> visited;
expr it1 = e1;
while (true) {
visited.insert(it1);
auto it1_n = m_entries.find(eqc_key(R, it1));
lean_assert(it1_n);
if (!it1_n->m_target)
break;
path1.push_back(*it1_n->m_target);
Hs1.push_back(flip_proof(R, *it1_n->m_proof, it1_n->m_flipped));
it1 = *it1_n->m_target;
}
lean_assert(it1 == n1->m_root);
// 2. The path from e2 to root must have at least one element c in visited
// Retrieve "path" from e2 to c
buffer<expr> path2, Hs2;
path2.push_back(e2);
expr it2 = e2;
while (true) {
if (visited.contains(it2))
break; // found common
auto it2_n = m_entries.find(eqc_key(R, it2));
lean_assert(it2_n);
lean_assert(it2_n->m_target);
path2.push_back(it2);
Hs2.push_back(flip_proof(R, *it2_n->m_proof, !it2_n->m_flipped));
it2 = *it2_n->m_target;
}
// it2 is the common element...
// 3. Shink path1/Hs1 until we find it2 (the common element)
while (true) {
if (path1.empty()) {
lean_assert(it2 == e1 && e1 == n2->m_root);
break;
}
if (path1.back() == it2) {
// found it!
break;
}
path1.pop_back();
Hs1.pop_back();
}
// 4. Build transitivity proof
optional<expr> pr;
expr lhs = e1;
for (unsigned i = 0; i < path1.size(); i++) {
pr = mk_trans(R, pr, mk_proof(R, lhs, path1[i], Hs1[i]));
lhs = path1[i];
}
unsigned i = Hs2.size();
while (i > 0) {
--i;
pr = mk_trans(R, pr, mk_proof(R, lhs, path2[i], Hs2[i]));
lhs = path2[i];
}
lean_assert(pr);
return pr;
} }
bool congruence_closure::is_uneqv(name const & R, expr const & e1, expr const & e2) const { bool congruence_closure::is_uneqv(name const & R, expr const & e1, expr const & e2) const {

View file

@ -43,7 +43,9 @@ class congruence_closure {
// store 'target' at 'm_target', and 'H' at 'm_proof'. Both fields are none if 'e' == m_root // store 'target' at 'm_target', and 'H' at 'm_proof'. Both fields are none if 'e' == m_root
optional<expr> m_target; optional<expr> m_target;
optional<expr> m_proof; optional<expr> m_proof;
bool m_flipped; // proof has been flipped
unsigned m_size; // number of elements in the equivalence class, it is meaningless if 'e' != m_root unsigned m_size; // number of elements in the equivalence class, it is meaningless if 'e' != m_root
}; };
/* Key (R, e) for the mapping (R, e) -> entry */ /* Key (R, e) for the mapping (R, e) -> entry */
@ -97,6 +99,9 @@ class congruence_closure {
parents m_parents; parents m_parents;
congruences m_congruences; congruences m_congruences;
void internalize_core(name const & R, expr const & e);
void process_todo();
int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const; int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const;
int compare_root(name const & R, expr e1, expr e2) const; int compare_root(name const & R, expr e1, expr e2) const;
unsigned symm_hash(name const & R, expr const & lhs, expr const & rhs) const; unsigned symm_hash(name const & R, expr const & lhs, expr const & rhs) const;
@ -113,6 +118,8 @@ class congruence_closure {
void add_eqv_step(name const & R, expr e1, expr e2, expr const & H); void add_eqv_step(name const & R, expr e1, expr e2, expr const & H);
void add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & H); void add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & H);
expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const;
void display_eqc(name const & R, expr const & e) const; void display_eqc(name const & R, expr const & e) const;
public: public:
void initialize(); void initialize();

View file

@ -0,0 +1,19 @@
import data.list
constant f {A : Type} : A → A → A
constant g : nat → nat
set_option blast.subst false
example (a b c : nat) : a = b → g a == g b :=
by blast
example (a b c : nat) : a = b → c = b → f (f a b) (g c) = f (f c a) (g b) :=
by blast
example (a b c d e x y : nat) : a = b → a = x → b = y → c = d → c = e → c = b → a = e :=
by blast
open perm
example (a b c d : list nat) : a ~ b → c ~ b → d ~ c → a ~ d :=
by blast