feat(library/blast/congruence_closure): create simpler congruence proofs when using blast.cc.heq

This commit is contained in:
Leonardo de Moura 2016-01-10 15:11:31 -08:00
parent ea7da31bba
commit e9d24ec152
3 changed files with 135 additions and 50 deletions

View file

@ -79,21 +79,24 @@ ext_congr_lemma::ext_congr_lemma(congr_lemma const & H):
m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())),
m_lift_needed(false),
m_fixed_fun(true),
m_heq_result(false) {}
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_based):
m_heq_result(false),
m_hcongr_lemma(false) {}
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed):
m_R(R),
m_congr_lemma(H),
m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())),
m_lift_needed(lift_needed),
m_fixed_fun(true),
m_heq_result(heq_based) {}
m_heq_result(false),
m_hcongr_lemma(false) {}
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed):
m_R(R),
m_congr_lemma(H),
m_rel_names(rel_names),
m_lift_needed(lift_needed),
m_fixed_fun(true),
m_heq_result(false) {}
m_heq_result(false),
m_hcongr_lemma(false) {}
/* We use the following cache for user-defined lemmas and automatically generated ones. */
typedef std::unordered_map<congr_lemma_key, optional<ext_congr_lemma>, congr_lemma_key_hash_fn, congr_lemma_key_eq_fn> congr_cache;
@ -381,12 +384,11 @@ static optional<ext_congr_lemma> mk_ext_specialized_congr_lemma(name const & R,
if (R == get_eq_name())
return optional<ext_congr_lemma>(res1);
bool lift_needed = true;
bool heq_result = false;
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_result);
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed);
}
/* Automatically generated congruence lemma based on heterogeneous equality. */
static optional<ext_congr_lemma> mk_hcongr_lemma(name const & R, expr const & fn, unsigned nargs) {
static optional<ext_congr_lemma> mk_hcongr_lemma_core(name const & R, expr const & fn, unsigned nargs) {
optional<congr_lemma> eq_congr = mk_hcongr_lemma(fn, nargs);
if (!eq_congr)
return optional<ext_congr_lemma>();
@ -398,6 +400,7 @@ static optional<ext_congr_lemma> mk_hcongr_lemma(name const & R, expr const & fn
res1.m_fixed_fun = false;
lean_assert(is_eq(type) || is_heq(type));
if (R == get_eq_name() || R == get_heq_name()) {
res1.m_hcongr_lemma = true;
if (is_heq(type))
res1.m_heq_result = true;
return optional<ext_congr_lemma>(res1);
@ -408,9 +411,10 @@ static optional<ext_congr_lemma> mk_hcongr_lemma(name const & R, expr const & fn
/* We cannot lift heterogeneous equality. */
return optional<ext_congr_lemma>();
} else {
bool heq_result = false;
bool lift_needed = true;
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_result);
ext_congr_lemma res2(R, *eq_congr, lift_needed);
res2.m_hcongr_lemma = true;
return optional<ext_congr_lemma>(res2);
}
}
@ -429,7 +433,7 @@ optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e) {
/* Try automatically generated lemma for equivalence relation over iff/eq */
if (!lemma) lemma = mk_relation_congr_lemma(R, fn, nargs);
/* Try automatically generated congruence lemma with support for heterogeneous equality. */
if (!lemma) lemma = mk_hcongr_lemma(R, fn, nargs);
if (!lemma) lemma = mk_hcongr_lemma_core(R, fn, nargs);
if (lemma) {
/* succeeded */
@ -475,6 +479,23 @@ optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e) {
return optional<ext_congr_lemma>();
}
optional<ext_congr_lemma> mk_ext_hcongr_lemma(expr const & fn, unsigned nargs) {
congr_lemma_key key1(get_eq_name(), fn, nargs);
auto it1 = g_congr_cache->find(key1);
if (it1 != g_congr_cache->end())
return it1->second;
if (auto lemma = mk_hcongr_lemma_core(get_eq_name(), fn, nargs)) {
/* succeeded */
g_congr_cache->insert(mk_pair(key1, lemma));
return lemma;
}
/* cache failure */
g_congr_cache->insert(mk_pair(key1, optional<ext_congr_lemma>()));
return optional<ext_congr_lemma>();
}
void congruence_closure::update_non_eq_relations(name const & R) {
if (R == get_eq_name())
return;
@ -687,7 +708,12 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp
new_entry.m_cg_root = old_k->m_expr;
m_entries.insert(k, new_entry);
// 2. Put new equivalence in the TODO queue
bool heq_proof = lemma.m_heq_result;
bool heq_proof = false;
if (lemma.m_heq_result) {
lean_assert(g_heq_based);
if (!is_def_eq(infer_type(e), infer_type(old_k->m_expr)))
heq_proof = true;
}
push_todo(lemma.m_R, e, old_k->m_expr, *g_congr_mark, heq_proof);
} else {
m_congruences.insert(k);
@ -1159,47 +1185,96 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e
auto lemma = mk_ext_congr_lemma(R, lhs);
lean_assert(lemma);
if (lemma->m_fixed_fun) {
list<optional<name>> const * it1 = &lemma->m_rel_names;
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
buffer<expr> lemma_args;
for (unsigned i = 0; i < lhs_args.size(); i++) {
lean_assert(*it1 && *it2);
switch (head(*it2)) {
case congr_arg_kind::HEq:
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Eq:
lean_assert(head(*it1));
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Fixed:
lemma_args.push_back(lhs_args[i]);
break;
case congr_arg_kind::FixedNoParam:
break;
case congr_arg_kind::Cast:
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
break;
if (g_heq_based && lemma->m_hcongr_lemma && (R == get_eq_name() || R == get_heq_name())) {
/* Try to simplify congruence proof by consuming common prefix of lhs and rhs */
/* This branch is an optimization, and it is not necessary */
unsigned i = 0;
for (; i < lhs_args.size(); i++) {
if (!is_def_eq(lhs_args[i], rhs_args[i]))
break;
}
it1 = &(tail(*it1));
it2 = &(tail(*it2));
unsigned prefix_sz = i;
unsigned rest_sz = lhs_args.size() - prefix_sz;
if (rest_sz == 0) {
if (heq_proofs)
return b.mk_heq_refl(lhs);
else
return b.mk_eq_refl(lhs);
}
expr g = lhs;
for (unsigned i = 0; i < rest_sz; i++) g = app_fn(g);
auto spec_lemma = mk_ext_hcongr_lemma(g, rest_sz);
lean_assert(spec_lemma);
list<congr_arg_kind> const * it = &spec_lemma->m_congr_lemma.get_arg_kinds();
buffer<expr> lemma_args;
for (unsigned i = prefix_sz; i < lhs_args.size(); i++) {
lean_assert(it);
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
if (head(*it) == congr_arg_kind::HEq) {
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
} else {
lean_assert(head(*it) == congr_arg_kind::Eq);
lemma_args.push_back(*get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]));
}
it = &(tail(*it));
}
expr r = mk_app(spec_lemma->m_congr_lemma.get_proof(), lemma_args);
if (spec_lemma->m_heq_result && !heq_proofs)
r = b.mk_eq_of_heq(r);
else if (!spec_lemma->m_heq_result && heq_proofs)
r = b.mk_heq_of_eq(r);
return r;
} else {
/* Main case: convers user-defined congruence lemmas, and
all automatically generated congruence lemmas */
list<optional<name>> const * it1 = &lemma->m_rel_names;
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
buffer<expr> lemma_args;
for (unsigned i = 0; i < lhs_args.size(); i++) {
lean_assert(*it1 && *it2);
switch (head(*it2)) {
case congr_arg_kind::HEq:
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Eq:
lean_assert(head(*it1));
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Fixed:
lemma_args.push_back(lhs_args[i]);
break;
case congr_arg_kind::FixedNoParam:
break;
case congr_arg_kind::Cast:
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
break;
}
it1 = &(tail(*it1));
it2 = &(tail(*it2));
}
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
if (lemma->m_lift_needed) {
lean_assert(!lemma->m_heq_result);
r = b.lift_from_eq(R, r);
}
if (lemma->m_heq_result && !heq_proofs)
r = b.mk_eq_of_heq(r);
else if (!lemma->m_heq_result && heq_proofs)
r = b.mk_heq_of_eq(r);
return r;
}
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
if (lemma->m_lift_needed) {
lean_assert(!lemma->m_heq_result);
r = b.lift_from_eq(R, r);
}
if (lemma->m_heq_result && !heq_proofs)
r = b.mk_eq_of_heq(r);
else if (!lemma->m_heq_result && heq_proofs)
r = b.mk_heq_of_eq(r);
return r;
} else {
/* This branch builds congruence proofs that handle equality between functions.
The proof is created using congr_arg/congr_fun/congr lemmas.
It can build proofs for congruence such as:
f = g -> a = b -> f a = g b
but it is limited to simply typed functions. */
optional<expr> r;
unsigned i = 0;
if (!is_def_eq(lhs_fn, rhs_fn)) {

View file

@ -263,8 +263,10 @@ struct ext_congr_lemma {
unsigned m_fixed_fun:1;
/* If m_heq_result is true, then lemma is based on heterogeneous equality and the conclusion is a heterogeneous equality. */
unsigned m_heq_result:1;
/* If m_heq_lemma is true, then lemma was created using mk_hcongr_lemma. */
unsigned m_hcongr_lemma:1;
ext_congr_lemma(congr_lemma const & H);
ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_result);
ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed);
ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed);
name const & get_relation() const { return m_R; }

View file

@ -0,0 +1,8 @@
set_option blast.strategy "cc"
set_option blast.cc.heq true
definition ex1 (a b c a' b' c' : nat) : a = a' → b = b' → c = c' → a + b + c + a = a' + b' + c' + a' :=
by blast
set_option pp.beta true
print ex1