feat(library/blast/congruence_closure): proof extraction
This commit is contained in:
parent
1e3f549c39
commit
28970ef717
4 changed files with 223 additions and 22 deletions
|
@ -33,10 +33,10 @@ static congruence_closure & get_cc() {
|
|||
|
||||
action_result assert_cc_action(hypothesis_idx hidx) {
|
||||
congruence_closure & cc = get_cc();
|
||||
// TODO(Leo): consider a target_changed event for branch_extension.
|
||||
cc.internalize(curr_state().get_target());
|
||||
cc.add(hidx);
|
||||
// TODO(Leo): remove the following line
|
||||
return action_result::new_branch();
|
||||
cc.display();
|
||||
// cc.display();
|
||||
if (cc.is_inconsistent()) {
|
||||
try {
|
||||
app_builder & b = get_app_builder();
|
||||
|
|
|
@ -96,6 +96,7 @@ void congruence_closure::mk_entry_for(name const & R, expr const & e) {
|
|||
n.m_root = e;
|
||||
n.m_cg_root = e;
|
||||
n.m_size = 1;
|
||||
n.m_flipped = false;
|
||||
m_entries.insert(eqc_key(R, e), n);
|
||||
}
|
||||
|
||||
|
@ -373,7 +374,7 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp
|
|||
check_iff_true(k);
|
||||
}
|
||||
|
||||
void congruence_closure::internalize(name const & R, expr const & e) {
|
||||
void congruence_closure::internalize_core(name const & R, expr const & e) {
|
||||
lean_assert(closed(e));
|
||||
if (has_expr_metavar(e))
|
||||
return;
|
||||
|
@ -390,13 +391,13 @@ void congruence_closure::internalize(name const & R, expr const & e) {
|
|||
return;
|
||||
case expr_kind::Macro:
|
||||
for (unsigned i = 0; i < macro_num_args(e); i++)
|
||||
internalize(R, macro_arg(e, i));
|
||||
internalize_core(R, macro_arg(e, i));
|
||||
mk_entry_for(R, e);
|
||||
break;
|
||||
case expr_kind::Pi:
|
||||
if (is_arrow(e) && is_prop(binding_domain(e)) && is_prop(binding_body(e))) {
|
||||
internalize(R, binding_domain(e));
|
||||
internalize(R, binding_body(e));
|
||||
internalize_core(R, binding_domain(e));
|
||||
internalize_core(R, binding_body(e));
|
||||
}
|
||||
if (is_prop(e)) {
|
||||
mk_entry_for(R, e);
|
||||
|
@ -411,13 +412,13 @@ void congruence_closure::internalize(name const & R, expr const & e) {
|
|||
for (expr const & arg : args) {
|
||||
lean_assert(*it);
|
||||
if (auto R1 = head(*it)) {
|
||||
internalize(*R1, arg);
|
||||
internalize_core(*R1, arg);
|
||||
add_occurrence(R, e, *R1, arg);
|
||||
}
|
||||
it = &tail(*it);
|
||||
}
|
||||
if (!lemma->m_fixed_fun) {
|
||||
internalize(get_eq_name(), fn);
|
||||
internalize_core(get_eq_name(), fn);
|
||||
add_occurrence(get_eq_name(), e, get_eq_name(), fn);
|
||||
}
|
||||
add_congruence_table(*lemma, e);
|
||||
|
@ -426,6 +427,12 @@ void congruence_closure::internalize(name const & R, expr const & e) {
|
|||
}}
|
||||
}
|
||||
|
||||
void congruence_closure::internalize(name const & R, expr const & e) {
|
||||
flet<congruence_closure *> set_cc(g_cc, this);
|
||||
internalize_core(R, e);
|
||||
process_todo();
|
||||
}
|
||||
|
||||
void congruence_closure::internalize(expr const & e) {
|
||||
if (is_prop(e))
|
||||
internalize(get_iff_name(), e);
|
||||
|
@ -453,8 +460,9 @@ void congruence_closure::invert_trans(name const & R, expr const & e, optional<e
|
|||
entry new_n = *n;
|
||||
if (n->m_target)
|
||||
invert_trans(R, *new_n.m_target, some_expr(e), new_n.m_proof);
|
||||
new_n.m_target = new_target;
|
||||
new_n.m_proof = new_proof;
|
||||
new_n.m_target = new_target;
|
||||
new_n.m_proof = new_proof;
|
||||
new_n.m_flipped = !new_n.m_flipped;
|
||||
m_entries.insert(k, new_n);
|
||||
}
|
||||
void congruence_closure::invert_trans(name const & R, expr const & e) {
|
||||
|
@ -496,6 +504,7 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
|
|||
auto r1 = m_entries.find(eqc_key(R, n1->m_root));
|
||||
auto r2 = m_entries.find(eqc_key(R, n2->m_root));
|
||||
lean_assert(r1 && r2);
|
||||
bool flipped = false;
|
||||
|
||||
// We want r2 to be the root of the combined class.
|
||||
|
||||
|
@ -504,6 +513,7 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
|
|||
std::swap(n1, n2);
|
||||
std::swap(r1, r2);
|
||||
// Remark: we don't apply symmetry eagerly. So, we don't adjust H.
|
||||
flipped = true;
|
||||
}
|
||||
|
||||
expr e1_root = n1->m_root;
|
||||
|
@ -516,8 +526,9 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
|
|||
// We want
|
||||
// r1 -> ... -> e1 -> e2 -> ... -> r2
|
||||
invert_trans(R, e1);
|
||||
new_n1.m_target = e2;
|
||||
new_n1.m_proof = H;
|
||||
new_n1.m_target = e2;
|
||||
new_n1.m_proof = H;
|
||||
new_n1.m_flipped = flipped;
|
||||
m_entries.insert(eqc_key(R, e1), new_n1);
|
||||
|
||||
// The hash code for the parents is going to change
|
||||
|
@ -565,9 +576,8 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
|
|||
m_parents.insert(k2, ps2);
|
||||
}
|
||||
|
||||
void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) {
|
||||
void congruence_closure::process_todo() {
|
||||
auto & todo = get_todo();
|
||||
todo.emplace_back(_R, _lhs, _rhs, _H);
|
||||
while (!todo.empty()) {
|
||||
name R; expr lhs, rhs, H;
|
||||
std::tie(R, lhs, rhs, H) = todo.back();
|
||||
|
@ -576,6 +586,11 @@ void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const
|
|||
}
|
||||
}
|
||||
|
||||
void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) {
|
||||
get_todo().emplace_back(_R, _lhs, _rhs, _H);
|
||||
process_todo();
|
||||
}
|
||||
|
||||
void congruence_closure::add(hypothesis_idx hidx) {
|
||||
if (is_inconsistent())
|
||||
return;
|
||||
|
@ -593,15 +608,15 @@ void congruence_closure::add(hypothesis_idx hidx) {
|
|||
name R; expr lhs, rhs;
|
||||
if (is_relation_app(p, R, lhs, rhs)) {
|
||||
if (is_neg) {
|
||||
internalize(get_iff_name(), p);
|
||||
internalize_core(get_iff_name(), p);
|
||||
add_eqv(get_iff_name(), p, mk_false(), b.mk_iff_false_intro(h.get_self()));
|
||||
} else {
|
||||
internalize(R, lhs);
|
||||
internalize(R, rhs);
|
||||
internalize_core(R, lhs);
|
||||
internalize_core(R, rhs);
|
||||
add_eqv(R, lhs, rhs, h.get_self());
|
||||
}
|
||||
} else if (is_prop(p)) {
|
||||
internalize(get_iff_name(), p);
|
||||
internalize_core(get_iff_name(), p);
|
||||
if (is_neg) {
|
||||
add_eqv(get_iff_name(), p, mk_false(), b.mk_iff_false_intro(h.get_self()));
|
||||
} else {
|
||||
|
@ -619,10 +634,170 @@ bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2
|
|||
return n1->m_root == n2->m_root;
|
||||
}
|
||||
|
||||
expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const {
|
||||
if (H == *g_congr_mark) {
|
||||
app_builder & b = get_app_builder();
|
||||
buffer<expr> lhs_args, rhs_args;
|
||||
expr const & lhs_fn = get_app_args(lhs, lhs_args);
|
||||
expr const & rhs_fn = get_app_args(rhs, rhs_args);
|
||||
lean_assert(lhs_args.size() == rhs_args.size());
|
||||
auto lemma = mk_ext_congr_lemma(R, lhs_fn, lhs_args.size());
|
||||
lean_assert(lemma);
|
||||
if (lemma->m_fixed_fun) {
|
||||
list<optional<name>> const * it1 = &lemma->m_rel_names;
|
||||
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
|
||||
buffer<expr> lemma_args;
|
||||
for (unsigned i = 0; i < lhs_args.size(); i++) {
|
||||
lean_assert(*it1 && *it2);
|
||||
switch (head(*it2)) {
|
||||
case congr_arg_kind::Eq:
|
||||
lean_assert(head(*it1));
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
|
||||
break;
|
||||
case congr_arg_kind::Fixed:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
break;
|
||||
case congr_arg_kind::Cast:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
break;
|
||||
}
|
||||
it1 = &(tail(*it1));
|
||||
it2 = &(tail(*it2));
|
||||
}
|
||||
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
|
||||
if (lemma->m_lift_needed)
|
||||
r = b.lift_from_eq(R, r);
|
||||
return r;
|
||||
} else {
|
||||
optional<expr> r;
|
||||
unsigned i = 0;
|
||||
if (!is_def_eq(lhs_fn, rhs_fn)) {
|
||||
r = get_eqv_proof(get_eq_name(), lhs_fn, rhs_fn);
|
||||
} else {
|
||||
for (; i < lhs_args.size(); i++) {
|
||||
if (!is_def_eq(lhs_args[i], rhs_args[i])) {
|
||||
expr g = mk_app(lhs_fn, i, lhs_args.data());
|
||||
expr Hi = *get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]);
|
||||
r = b.mk_congr_arg(g, Hi);
|
||||
i++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!r) {
|
||||
// lhs and rhs are definitionally equal
|
||||
r = b.mk_eq_refl(lhs);
|
||||
}
|
||||
}
|
||||
lean_assert(r);
|
||||
for (; i < lhs_args.size(); i++) {
|
||||
if (is_def_eq(lhs_args[i], rhs_args[i])) {
|
||||
r = b.mk_congr_fun(*r, lhs_args[i]);
|
||||
} else {
|
||||
expr Hi = *get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]);
|
||||
r = b.mk_congr(*r, Hi);
|
||||
}
|
||||
}
|
||||
if (lemma->m_lift_needed)
|
||||
r = b.lift_from_eq(R, *r);
|
||||
return *r;
|
||||
}
|
||||
} else if (H == *g_iff_true_mark) {
|
||||
// TODO(Leo):
|
||||
lean_unreachable();
|
||||
} else {
|
||||
return H;
|
||||
}
|
||||
}
|
||||
|
||||
static expr flip_proof(name const & R, expr const & H, bool flipped) {
|
||||
if (!flipped) {
|
||||
return H;
|
||||
} else if (H == *g_congr_mark) {
|
||||
return H;
|
||||
} else if (H == *g_iff_true_mark) {
|
||||
return H;
|
||||
} else {
|
||||
return get_app_builder().mk_symm(R, H);
|
||||
}
|
||||
}
|
||||
|
||||
static expr mk_trans(name const & R, optional<expr> const & H1, expr const & H2) {
|
||||
return !H1 ? H2 : get_app_builder().mk_trans(R, *H1, H2);
|
||||
}
|
||||
|
||||
optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1, expr const & e2) const {
|
||||
// TODO(Leo):
|
||||
std::cout << R << e1 << e2 << "\n";
|
||||
return none_expr();
|
||||
app_builder & b = get_app_builder();
|
||||
if (has_expr_metavar(e1) || has_expr_metavar(e2)) return none_expr();
|
||||
if (is_def_eq(e1, e2))
|
||||
return some_expr(b.lift_from_eq(R, b.mk_eq_refl(e1)));
|
||||
auto n1 = m_entries.find(eqc_key(R, e1));
|
||||
if (!n1) return none_expr();
|
||||
auto n2 = m_entries.find(eqc_key(R, e2));
|
||||
if (!n2) return none_expr();
|
||||
if (n1->m_root != n2->m_root) return none_expr();
|
||||
// 1. Retrieve "path" from e1 to root
|
||||
buffer<expr> path1, Hs1;
|
||||
rb_tree<expr, expr_quick_cmp> visited;
|
||||
expr it1 = e1;
|
||||
while (true) {
|
||||
visited.insert(it1);
|
||||
auto it1_n = m_entries.find(eqc_key(R, it1));
|
||||
lean_assert(it1_n);
|
||||
if (!it1_n->m_target)
|
||||
break;
|
||||
path1.push_back(*it1_n->m_target);
|
||||
Hs1.push_back(flip_proof(R, *it1_n->m_proof, it1_n->m_flipped));
|
||||
it1 = *it1_n->m_target;
|
||||
}
|
||||
lean_assert(it1 == n1->m_root);
|
||||
// 2. The path from e2 to root must have at least one element c in visited
|
||||
// Retrieve "path" from e2 to c
|
||||
buffer<expr> path2, Hs2;
|
||||
path2.push_back(e2);
|
||||
expr it2 = e2;
|
||||
while (true) {
|
||||
if (visited.contains(it2))
|
||||
break; // found common
|
||||
auto it2_n = m_entries.find(eqc_key(R, it2));
|
||||
lean_assert(it2_n);
|
||||
lean_assert(it2_n->m_target);
|
||||
path2.push_back(it2);
|
||||
Hs2.push_back(flip_proof(R, *it2_n->m_proof, !it2_n->m_flipped));
|
||||
it2 = *it2_n->m_target;
|
||||
}
|
||||
// it2 is the common element...
|
||||
// 3. Shink path1/Hs1 until we find it2 (the common element)
|
||||
while (true) {
|
||||
if (path1.empty()) {
|
||||
lean_assert(it2 == e1 && e1 == n2->m_root);
|
||||
break;
|
||||
}
|
||||
if (path1.back() == it2) {
|
||||
// found it!
|
||||
break;
|
||||
}
|
||||
path1.pop_back();
|
||||
Hs1.pop_back();
|
||||
}
|
||||
|
||||
// 4. Build transitivity proof
|
||||
optional<expr> pr;
|
||||
expr lhs = e1;
|
||||
for (unsigned i = 0; i < path1.size(); i++) {
|
||||
pr = mk_trans(R, pr, mk_proof(R, lhs, path1[i], Hs1[i]));
|
||||
lhs = path1[i];
|
||||
}
|
||||
unsigned i = Hs2.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
pr = mk_trans(R, pr, mk_proof(R, lhs, path2[i], Hs2[i]));
|
||||
lhs = path2[i];
|
||||
}
|
||||
lean_assert(pr);
|
||||
return pr;
|
||||
}
|
||||
|
||||
bool congruence_closure::is_uneqv(name const & R, expr const & e1, expr const & e2) const {
|
||||
|
|
|
@ -43,7 +43,9 @@ class congruence_closure {
|
|||
// store 'target' at 'm_target', and 'H' at 'm_proof'. Both fields are none if 'e' == m_root
|
||||
optional<expr> m_target;
|
||||
optional<expr> m_proof;
|
||||
bool m_flipped; // proof has been flipped
|
||||
unsigned m_size; // number of elements in the equivalence class, it is meaningless if 'e' != m_root
|
||||
|
||||
};
|
||||
|
||||
/* Key (R, e) for the mapping (R, e) -> entry */
|
||||
|
@ -97,6 +99,9 @@ class congruence_closure {
|
|||
parents m_parents;
|
||||
congruences m_congruences;
|
||||
|
||||
void internalize_core(name const & R, expr const & e);
|
||||
void process_todo();
|
||||
|
||||
int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const;
|
||||
int compare_root(name const & R, expr e1, expr e2) const;
|
||||
unsigned symm_hash(name const & R, expr const & lhs, expr const & rhs) const;
|
||||
|
@ -113,6 +118,8 @@ class congruence_closure {
|
|||
void add_eqv_step(name const & R, expr e1, expr e2, expr const & H);
|
||||
void add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & H);
|
||||
|
||||
expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const;
|
||||
|
||||
void display_eqc(name const & R, expr const & e) const;
|
||||
public:
|
||||
void initialize();
|
||||
|
|
19
tests/lean/run/blast_cc1.lean
Normal file
19
tests/lean/run/blast_cc1.lean
Normal file
|
@ -0,0 +1,19 @@
|
|||
import data.list
|
||||
|
||||
constant f {A : Type} : A → A → A
|
||||
constant g : nat → nat
|
||||
set_option blast.subst false
|
||||
|
||||
example (a b c : nat) : a = b → g a == g b :=
|
||||
by blast
|
||||
|
||||
example (a b c : nat) : a = b → c = b → f (f a b) (g c) = f (f c a) (g b) :=
|
||||
by blast
|
||||
|
||||
example (a b c d e x y : nat) : a = b → a = x → b = y → c = d → c = e → c = b → a = e :=
|
||||
by blast
|
||||
|
||||
open perm
|
||||
|
||||
example (a b c d : list nat) : a ~ b → c ~ b → d ~ c → a ~ d :=
|
||||
by blast
|
Loading…
Reference in a new issue