feat(library/blast/congruence_closure): create simpler congruence proofs when using blast.cc.heq
This commit is contained in:
parent
ea7da31bba
commit
e9d24ec152
3 changed files with 135 additions and 50 deletions
|
@ -79,21 +79,24 @@ ext_congr_lemma::ext_congr_lemma(congr_lemma const & H):
|
|||
m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())),
|
||||
m_lift_needed(false),
|
||||
m_fixed_fun(true),
|
||||
m_heq_result(false) {}
|
||||
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_based):
|
||||
m_heq_result(false),
|
||||
m_hcongr_lemma(false) {}
|
||||
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed):
|
||||
m_R(R),
|
||||
m_congr_lemma(H),
|
||||
m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())),
|
||||
m_lift_needed(lift_needed),
|
||||
m_fixed_fun(true),
|
||||
m_heq_result(heq_based) {}
|
||||
m_heq_result(false),
|
||||
m_hcongr_lemma(false) {}
|
||||
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed):
|
||||
m_R(R),
|
||||
m_congr_lemma(H),
|
||||
m_rel_names(rel_names),
|
||||
m_lift_needed(lift_needed),
|
||||
m_fixed_fun(true),
|
||||
m_heq_result(false) {}
|
||||
m_heq_result(false),
|
||||
m_hcongr_lemma(false) {}
|
||||
|
||||
/* We use the following cache for user-defined lemmas and automatically generated ones. */
|
||||
typedef std::unordered_map<congr_lemma_key, optional<ext_congr_lemma>, congr_lemma_key_hash_fn, congr_lemma_key_eq_fn> congr_cache;
|
||||
|
@ -381,12 +384,11 @@ static optional<ext_congr_lemma> mk_ext_specialized_congr_lemma(name const & R,
|
|||
if (R == get_eq_name())
|
||||
return optional<ext_congr_lemma>(res1);
|
||||
bool lift_needed = true;
|
||||
bool heq_result = false;
|
||||
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_result);
|
||||
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed);
|
||||
}
|
||||
|
||||
/* Automatically generated congruence lemma based on heterogeneous equality. */
|
||||
static optional<ext_congr_lemma> mk_hcongr_lemma(name const & R, expr const & fn, unsigned nargs) {
|
||||
static optional<ext_congr_lemma> mk_hcongr_lemma_core(name const & R, expr const & fn, unsigned nargs) {
|
||||
optional<congr_lemma> eq_congr = mk_hcongr_lemma(fn, nargs);
|
||||
if (!eq_congr)
|
||||
return optional<ext_congr_lemma>();
|
||||
|
@ -398,6 +400,7 @@ static optional<ext_congr_lemma> mk_hcongr_lemma(name const & R, expr const & fn
|
|||
res1.m_fixed_fun = false;
|
||||
lean_assert(is_eq(type) || is_heq(type));
|
||||
if (R == get_eq_name() || R == get_heq_name()) {
|
||||
res1.m_hcongr_lemma = true;
|
||||
if (is_heq(type))
|
||||
res1.m_heq_result = true;
|
||||
return optional<ext_congr_lemma>(res1);
|
||||
|
@ -408,9 +411,10 @@ static optional<ext_congr_lemma> mk_hcongr_lemma(name const & R, expr const & fn
|
|||
/* We cannot lift heterogeneous equality. */
|
||||
return optional<ext_congr_lemma>();
|
||||
} else {
|
||||
bool heq_result = false;
|
||||
bool lift_needed = true;
|
||||
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_result);
|
||||
ext_congr_lemma res2(R, *eq_congr, lift_needed);
|
||||
res2.m_hcongr_lemma = true;
|
||||
return optional<ext_congr_lemma>(res2);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -429,7 +433,7 @@ optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e) {
|
|||
/* Try automatically generated lemma for equivalence relation over iff/eq */
|
||||
if (!lemma) lemma = mk_relation_congr_lemma(R, fn, nargs);
|
||||
/* Try automatically generated congruence lemma with support for heterogeneous equality. */
|
||||
if (!lemma) lemma = mk_hcongr_lemma(R, fn, nargs);
|
||||
if (!lemma) lemma = mk_hcongr_lemma_core(R, fn, nargs);
|
||||
|
||||
if (lemma) {
|
||||
/* succeeded */
|
||||
|
@ -475,6 +479,23 @@ optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e) {
|
|||
return optional<ext_congr_lemma>();
|
||||
}
|
||||
|
||||
optional<ext_congr_lemma> mk_ext_hcongr_lemma(expr const & fn, unsigned nargs) {
|
||||
congr_lemma_key key1(get_eq_name(), fn, nargs);
|
||||
auto it1 = g_congr_cache->find(key1);
|
||||
if (it1 != g_congr_cache->end())
|
||||
return it1->second;
|
||||
|
||||
if (auto lemma = mk_hcongr_lemma_core(get_eq_name(), fn, nargs)) {
|
||||
/* succeeded */
|
||||
g_congr_cache->insert(mk_pair(key1, lemma));
|
||||
return lemma;
|
||||
}
|
||||
|
||||
/* cache failure */
|
||||
g_congr_cache->insert(mk_pair(key1, optional<ext_congr_lemma>()));
|
||||
return optional<ext_congr_lemma>();
|
||||
}
|
||||
|
||||
void congruence_closure::update_non_eq_relations(name const & R) {
|
||||
if (R == get_eq_name())
|
||||
return;
|
||||
|
@ -687,7 +708,12 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp
|
|||
new_entry.m_cg_root = old_k->m_expr;
|
||||
m_entries.insert(k, new_entry);
|
||||
// 2. Put new equivalence in the TODO queue
|
||||
bool heq_proof = lemma.m_heq_result;
|
||||
bool heq_proof = false;
|
||||
if (lemma.m_heq_result) {
|
||||
lean_assert(g_heq_based);
|
||||
if (!is_def_eq(infer_type(e), infer_type(old_k->m_expr)))
|
||||
heq_proof = true;
|
||||
}
|
||||
push_todo(lemma.m_R, e, old_k->m_expr, *g_congr_mark, heq_proof);
|
||||
} else {
|
||||
m_congruences.insert(k);
|
||||
|
@ -1159,47 +1185,96 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e
|
|||
auto lemma = mk_ext_congr_lemma(R, lhs);
|
||||
lean_assert(lemma);
|
||||
if (lemma->m_fixed_fun) {
|
||||
list<optional<name>> const * it1 = &lemma->m_rel_names;
|
||||
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
|
||||
buffer<expr> lemma_args;
|
||||
for (unsigned i = 0; i < lhs_args.size(); i++) {
|
||||
lean_assert(*it1 && *it2);
|
||||
switch (head(*it2)) {
|
||||
case congr_arg_kind::HEq:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
|
||||
break;
|
||||
case congr_arg_kind::Eq:
|
||||
lean_assert(head(*it1));
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
|
||||
break;
|
||||
case congr_arg_kind::Fixed:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
break;
|
||||
case congr_arg_kind::FixedNoParam:
|
||||
break;
|
||||
case congr_arg_kind::Cast:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
break;
|
||||
if (g_heq_based && lemma->m_hcongr_lemma && (R == get_eq_name() || R == get_heq_name())) {
|
||||
/* Try to simplify congruence proof by consuming common prefix of lhs and rhs */
|
||||
/* This branch is an optimization, and it is not necessary */
|
||||
unsigned i = 0;
|
||||
for (; i < lhs_args.size(); i++) {
|
||||
if (!is_def_eq(lhs_args[i], rhs_args[i]))
|
||||
break;
|
||||
}
|
||||
it1 = &(tail(*it1));
|
||||
it2 = &(tail(*it2));
|
||||
unsigned prefix_sz = i;
|
||||
unsigned rest_sz = lhs_args.size() - prefix_sz;
|
||||
if (rest_sz == 0) {
|
||||
if (heq_proofs)
|
||||
return b.mk_heq_refl(lhs);
|
||||
else
|
||||
return b.mk_eq_refl(lhs);
|
||||
}
|
||||
expr g = lhs;
|
||||
for (unsigned i = 0; i < rest_sz; i++) g = app_fn(g);
|
||||
auto spec_lemma = mk_ext_hcongr_lemma(g, rest_sz);
|
||||
lean_assert(spec_lemma);
|
||||
list<congr_arg_kind> const * it = &spec_lemma->m_congr_lemma.get_arg_kinds();
|
||||
buffer<expr> lemma_args;
|
||||
for (unsigned i = prefix_sz; i < lhs_args.size(); i++) {
|
||||
lean_assert(it);
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
if (head(*it) == congr_arg_kind::HEq) {
|
||||
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
|
||||
} else {
|
||||
lean_assert(head(*it) == congr_arg_kind::Eq);
|
||||
lemma_args.push_back(*get_eqv_proof(get_eq_name(), lhs_args[i], rhs_args[i]));
|
||||
}
|
||||
it = &(tail(*it));
|
||||
}
|
||||
expr r = mk_app(spec_lemma->m_congr_lemma.get_proof(), lemma_args);
|
||||
if (spec_lemma->m_heq_result && !heq_proofs)
|
||||
r = b.mk_eq_of_heq(r);
|
||||
else if (!spec_lemma->m_heq_result && heq_proofs)
|
||||
r = b.mk_heq_of_eq(r);
|
||||
return r;
|
||||
} else {
|
||||
/* Main case: convers user-defined congruence lemmas, and
|
||||
all automatically generated congruence lemmas */
|
||||
list<optional<name>> const * it1 = &lemma->m_rel_names;
|
||||
list<congr_arg_kind> const * it2 = &lemma->m_congr_lemma.get_arg_kinds();
|
||||
buffer<expr> lemma_args;
|
||||
for (unsigned i = 0; i < lhs_args.size(); i++) {
|
||||
lean_assert(*it1 && *it2);
|
||||
switch (head(*it2)) {
|
||||
case congr_arg_kind::HEq:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
|
||||
break;
|
||||
case congr_arg_kind::Eq:
|
||||
lean_assert(head(*it1));
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
lemma_args.push_back(*get_eqv_proof(*head(*it1), lhs_args[i], rhs_args[i]));
|
||||
break;
|
||||
case congr_arg_kind::Fixed:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
break;
|
||||
case congr_arg_kind::FixedNoParam:
|
||||
break;
|
||||
case congr_arg_kind::Cast:
|
||||
lemma_args.push_back(lhs_args[i]);
|
||||
lemma_args.push_back(rhs_args[i]);
|
||||
break;
|
||||
}
|
||||
it1 = &(tail(*it1));
|
||||
it2 = &(tail(*it2));
|
||||
}
|
||||
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
|
||||
if (lemma->m_lift_needed) {
|
||||
lean_assert(!lemma->m_heq_result);
|
||||
r = b.lift_from_eq(R, r);
|
||||
}
|
||||
if (lemma->m_heq_result && !heq_proofs)
|
||||
r = b.mk_eq_of_heq(r);
|
||||
else if (!lemma->m_heq_result && heq_proofs)
|
||||
r = b.mk_heq_of_eq(r);
|
||||
return r;
|
||||
}
|
||||
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
|
||||
if (lemma->m_lift_needed) {
|
||||
lean_assert(!lemma->m_heq_result);
|
||||
r = b.lift_from_eq(R, r);
|
||||
}
|
||||
if (lemma->m_heq_result && !heq_proofs)
|
||||
r = b.mk_eq_of_heq(r);
|
||||
else if (!lemma->m_heq_result && heq_proofs)
|
||||
r = b.mk_heq_of_eq(r);
|
||||
return r;
|
||||
} else {
|
||||
/* This branch builds congruence proofs that handle equality between functions.
|
||||
The proof is created using congr_arg/congr_fun/congr lemmas.
|
||||
It can build proofs for congruence such as:
|
||||
f = g -> a = b -> f a = g b
|
||||
but it is limited to simply typed functions. */
|
||||
optional<expr> r;
|
||||
unsigned i = 0;
|
||||
if (!is_def_eq(lhs_fn, rhs_fn)) {
|
||||
|
|
|
@ -263,8 +263,10 @@ struct ext_congr_lemma {
|
|||
unsigned m_fixed_fun:1;
|
||||
/* If m_heq_result is true, then lemma is based on heterogeneous equality and the conclusion is a heterogeneous equality. */
|
||||
unsigned m_heq_result:1;
|
||||
/* If m_heq_lemma is true, then lemma was created using mk_hcongr_lemma. */
|
||||
unsigned m_hcongr_lemma:1;
|
||||
ext_congr_lemma(congr_lemma const & H);
|
||||
ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_result);
|
||||
ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed);
|
||||
ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed);
|
||||
|
||||
name const & get_relation() const { return m_R; }
|
||||
|
|
8
tests/lean/run/blast_cc_heq5.lean
Normal file
8
tests/lean/run/blast_cc_heq5.lean
Normal file
|
@ -0,0 +1,8 @@
|
|||
set_option blast.strategy "cc"
|
||||
set_option blast.cc.heq true
|
||||
|
||||
definition ex1 (a b c a' b' c' : nat) : a = a' → b = b' → c = c' → a + b + c + a = a' + b' + c' + a' :=
|
||||
by blast
|
||||
|
||||
set_option pp.beta true
|
||||
print ex1
|
Loading…
Reference in a new issue