feat(library/congr_lemma_manager): add congruence lemma for normalization and congruence closure
This commit is contained in:
parent
15cd92f89f
commit
1d1cd0fc24
2 changed files with 150 additions and 4 deletions
|
@ -8,6 +8,7 @@ Author: Leonardo de Moura
|
|||
#include "kernel/abstract.h"
|
||||
#include "library/util.h"
|
||||
#include "library/locals.h"
|
||||
#include "library/constants.h"
|
||||
#include "library/replace_visitor.h"
|
||||
#include "library/congr_lemma_manager.h"
|
||||
|
||||
|
@ -35,6 +36,7 @@ class congr_lemma_manager::imp {
|
|||
}
|
||||
};
|
||||
|
||||
std::unordered_map<key, result, key_hash_fn, key_eq_fn> m_simp_cache;
|
||||
std::unordered_map<key, result, key_hash_fn, key_eq_fn> m_cache;
|
||||
|
||||
expr infer(expr const & e) { return m_ctx.infer(e); }
|
||||
|
@ -192,6 +194,100 @@ class congr_lemma_manager::imp {
|
|||
}
|
||||
}
|
||||
|
||||
optional<result> mk_congr(expr const & fn, optional<result> const & simp_lemma,
|
||||
buffer<param_info> const & pinfos, buffer<congr_arg_kind> const & kinds) {
|
||||
try {
|
||||
expr fn_type1 = whnf(infer(fn));
|
||||
expr fn_type2 = fn_type1;
|
||||
name e_name("e");
|
||||
buffer<expr> lhss;
|
||||
buffer<expr> rhss; // it contains the right-hand-side argument
|
||||
buffer<optional<expr>> eqs; // for Eq args, it contains the equality
|
||||
buffer<expr> hyps; // contains lhss + rhss + eqs
|
||||
buffer<expr> simp_lemma_args;
|
||||
for (unsigned i = 0; i < pinfos.size(); i++) {
|
||||
if (!is_pi(fn_type1)) {
|
||||
return optional<result>();
|
||||
}
|
||||
expr lhs = m_ctx.mk_tmp_local(binding_name(fn_type1), binding_domain(fn_type1));
|
||||
expr rhs;
|
||||
lhss.push_back(lhs);
|
||||
hyps.push_back(lhs);
|
||||
simp_lemma_args.push_back(lhs);
|
||||
switch (kinds[i]) {
|
||||
case congr_arg_kind::Eq: {
|
||||
lean_assert(m_ctx.is_def_eq(binding_domain(fn_type1), binding_domain(fn_type2)));
|
||||
rhs = m_ctx.mk_tmp_local(binding_name(fn_type2), binding_domain(fn_type2));
|
||||
expr eq_type = m_builder.mk_eq(lhs, rhs);
|
||||
rhss.push_back(rhs);
|
||||
expr eq = m_ctx.mk_tmp_local(e_name.append_after(eqs.size()+1), eq_type);
|
||||
eqs.push_back(some_expr(eq));
|
||||
hyps.push_back(rhs);
|
||||
hyps.push_back(eq);
|
||||
simp_lemma_args.push_back(rhs);
|
||||
simp_lemma_args.push_back(eq);
|
||||
break;
|
||||
}
|
||||
case congr_arg_kind::Fixed:
|
||||
rhs = lhs;
|
||||
rhss.push_back(rhs);
|
||||
eqs.push_back(none_expr());
|
||||
break;
|
||||
case congr_arg_kind::Cast: {
|
||||
rhs = m_ctx.mk_tmp_local(binding_name(fn_type2), binding_domain(fn_type2));
|
||||
rhss.push_back(rhs);
|
||||
eqs.push_back(none_expr());
|
||||
hyps.push_back(rhs);
|
||||
break;
|
||||
}}
|
||||
fn_type1 = whnf(instantiate(binding_body(fn_type1), lhs));
|
||||
fn_type2 = whnf(instantiate(binding_body(fn_type2), rhs));
|
||||
}
|
||||
expr pr1 = mk_app(simp_lemma->get_proof(), simp_lemma_args);
|
||||
expr type1 = simp_lemma->get_type();
|
||||
while (is_pi(type1))
|
||||
type1 = binding_body(type1);
|
||||
type1 = instantiate_rev(type1, simp_lemma_args.size(), simp_lemma_args.data());
|
||||
expr lhs1, rhs1;
|
||||
lean_verify(is_eq(type1, lhs1, rhs1));
|
||||
// build proof2
|
||||
expr rhs2 = mk_app(fn, rhss);
|
||||
expr eq = m_builder.mk_eq(lhs1, rhs2);
|
||||
expr congr_type = Pi(hyps, eq);
|
||||
// build proof that rhs1 = rhs2
|
||||
unsigned i;
|
||||
for (i = 0; i < kinds.size(); i++) {
|
||||
if (kinds[i] == congr_arg_kind::Cast && !pinfos[i].is_prop())
|
||||
break;
|
||||
}
|
||||
if (i == kinds.size()) {
|
||||
// rhs1 and rhs2 are definitionally equal
|
||||
expr congr_proof = Fun(hyps, pr1);
|
||||
return optional<result>(congr_type, congr_proof, to_list(kinds));
|
||||
}
|
||||
buffer<expr> rhss1;
|
||||
get_app_args(rhs1, rhss1);
|
||||
lean_assert(rhss.size() == rhss1.size());
|
||||
expr a = mk_app(fn, i, rhss1.data());
|
||||
expr pr2 = m_builder.mk_eq_refl(a);
|
||||
for (; i < kinds.size(); i++) {
|
||||
if (kinds[i] == congr_arg_kind::Cast && !pinfos[i].is_prop()) {
|
||||
lean_assert(pinfos[i].is_subsingleton());
|
||||
expr r1 = rhss1[i];
|
||||
expr r2 = rhss[i];
|
||||
expr r1_eq_r2 = m_builder.mk_app(get_subsingleton_elim_name(), r1, r2);
|
||||
pr2 = m_builder.mk_congr(pr2, r1_eq_r2);
|
||||
} else {
|
||||
pr2 = m_builder.mk_congr_fun(pr2, rhss[i]);
|
||||
}
|
||||
}
|
||||
expr congr_proof = Fun(hyps, m_builder.mk_eq_trans(pr1, pr2));
|
||||
return optional<result>(congr_type, congr_proof, to_list(kinds));
|
||||
} catch (app_builder_exception &) {
|
||||
return optional<result>();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
imp(app_builder & b, fun_info_manager & fm, bool ignore_inst_implicit):
|
||||
m_builder(b), m_fmanager(fm), m_ctx(fm.ctx()), m_ignore_inst_implicit(ignore_inst_implicit) {}
|
||||
|
@ -202,8 +298,8 @@ public:
|
|||
}
|
||||
|
||||
optional<result> mk_congr_simp(expr const & fn, unsigned nargs) {
|
||||
auto r = m_cache.find(key(fn, nargs));
|
||||
if (r != m_cache.end())
|
||||
auto r = m_simp_cache.find(key(fn, nargs));
|
||||
if (r != m_simp_cache.end())
|
||||
return optional<result>(r->second);
|
||||
fun_info finfo = m_fmanager.get(fn, nargs);
|
||||
list<unsigned> const & result_deps = finfo.get_dependencies();
|
||||
|
@ -238,7 +334,7 @@ public:
|
|||
}
|
||||
auto new_r = mk_congr_simp(fn, pinfos, kinds);
|
||||
if (new_r) {
|
||||
m_cache.insert(mk_pair(key(fn, nargs), *new_r));
|
||||
m_simp_cache.insert(mk_pair(key(fn, nargs), *new_r));
|
||||
return new_r;
|
||||
} else if (has_cast(kinds)) {
|
||||
// remove casts and try again
|
||||
|
@ -248,7 +344,7 @@ public:
|
|||
}
|
||||
auto new_r = mk_congr_simp(fn, pinfos, kinds);
|
||||
if (new_r) {
|
||||
m_cache.insert(mk_pair(key(fn, nargs), *new_r));
|
||||
m_simp_cache.insert(mk_pair(key(fn, nargs), *new_r));
|
||||
return new_r;
|
||||
} else {
|
||||
return new_r;
|
||||
|
@ -257,6 +353,47 @@ public:
|
|||
return new_r;
|
||||
}
|
||||
}
|
||||
|
||||
optional<result> mk_congr(expr const & fn) {
|
||||
fun_info finfo = m_fmanager.get(fn);
|
||||
return mk_congr(fn, finfo.get_arity());
|
||||
}
|
||||
|
||||
optional<result> mk_congr(expr const & fn, unsigned nargs) {
|
||||
auto r = m_cache.find(key(fn, nargs));
|
||||
if (r != m_cache.end())
|
||||
return optional<result>(r->second);
|
||||
fun_info finfo = m_fmanager.get(fn, nargs);
|
||||
optional<result> simp_lemma = mk_congr_simp(fn, nargs);
|
||||
if (!simp_lemma)
|
||||
return optional<result>();
|
||||
buffer<congr_arg_kind> kinds;
|
||||
buffer<param_info> pinfos;
|
||||
to_buffer(simp_lemma->get_arg_kinds(), kinds);
|
||||
to_buffer(finfo.get_params_info(), pinfos);
|
||||
// For congr lemmas we have the following restriction:
|
||||
// if a Cast arg is subsingleton, it is not a proposition,
|
||||
// and it is a dependent argument, then we mark it as fixed.
|
||||
// This restriction doesn't affect the standard library,
|
||||
// but it simplifies the implementation.
|
||||
lean_assert(kinds.size() == pinfos.size());
|
||||
bool has_cast = false;
|
||||
for (unsigned i = 0; i < kinds.size(); i++) {
|
||||
if (!pinfos[i].is_prop() && pinfos[i].is_subsingleton() && pinfos[i].is_dep()) {
|
||||
kinds[i] = congr_arg_kind::Fixed;
|
||||
}
|
||||
if (kinds[i] == congr_arg_kind::Cast)
|
||||
has_cast = true;
|
||||
}
|
||||
if (!has_cast) {
|
||||
m_cache.insert(mk_pair(key(fn, nargs), *simp_lemma));
|
||||
return simp_lemma; // simp_lemma will be identical to regular congr lemma
|
||||
}
|
||||
auto new_r = mk_congr(fn, simp_lemma, pinfos, kinds);
|
||||
if (new_r)
|
||||
m_cache.insert(mk_pair(key(fn, nargs), *new_r));
|
||||
return new_r;
|
||||
}
|
||||
};
|
||||
|
||||
congr_lemma_manager::congr_lemma_manager(app_builder & b, fun_info_manager & fm, bool ignore_inst_implicit):
|
||||
|
@ -272,4 +409,10 @@ auto congr_lemma_manager::mk_congr_simp(expr const & fn) -> optional<result> {
|
|||
auto congr_lemma_manager::mk_congr_simp(expr const & fn, unsigned nargs) -> optional<result> {
|
||||
return m_ptr->mk_congr_simp(fn, nargs);
|
||||
}
|
||||
auto congr_lemma_manager::mk_congr(expr const & fn) -> optional<result> {
|
||||
return m_ptr->mk_congr(fn);
|
||||
}
|
||||
auto congr_lemma_manager::mk_congr(expr const & fn, unsigned nargs) -> optional<result> {
|
||||
return m_ptr->mk_congr(fn, nargs);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,5 +41,8 @@ public:
|
|||
|
||||
optional<result> mk_congr_simp(expr const & fn);
|
||||
optional<result> mk_congr_simp(expr const & fn, unsigned nargs);
|
||||
|
||||
optional<result> mk_congr(expr const & fn);
|
||||
optional<result> mk_congr(expr const & fn, unsigned nargs);
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue