feat(library/congr_lemma_manager): add congruence lemma for normalization and congruence closure

This commit is contained in:
Leonardo de Moura 2015-11-12 18:54:12 -08:00
parent 15cd92f89f
commit 1d1cd0fc24
2 changed files with 150 additions and 4 deletions

View file

@ -8,6 +8,7 @@ Author: Leonardo de Moura
#include "kernel/abstract.h" #include "kernel/abstract.h"
#include "library/util.h" #include "library/util.h"
#include "library/locals.h" #include "library/locals.h"
#include "library/constants.h"
#include "library/replace_visitor.h" #include "library/replace_visitor.h"
#include "library/congr_lemma_manager.h" #include "library/congr_lemma_manager.h"
@ -35,6 +36,7 @@ class congr_lemma_manager::imp {
} }
}; };
std::unordered_map<key, result, key_hash_fn, key_eq_fn> m_simp_cache;
std::unordered_map<key, result, key_hash_fn, key_eq_fn> m_cache; std::unordered_map<key, result, key_hash_fn, key_eq_fn> m_cache;
expr infer(expr const & e) { return m_ctx.infer(e); } expr infer(expr const & e) { return m_ctx.infer(e); }
@ -192,6 +194,100 @@ class congr_lemma_manager::imp {
} }
} }
optional<result> mk_congr(expr const & fn, optional<result> const & simp_lemma,
buffer<param_info> const & pinfos, buffer<congr_arg_kind> const & kinds) {
try {
expr fn_type1 = whnf(infer(fn));
expr fn_type2 = fn_type1;
name e_name("e");
buffer<expr> lhss;
buffer<expr> rhss; // it contains the right-hand-side argument
buffer<optional<expr>> eqs; // for Eq args, it contains the equality
buffer<expr> hyps; // contains lhss + rhss + eqs
buffer<expr> simp_lemma_args;
for (unsigned i = 0; i < pinfos.size(); i++) {
if (!is_pi(fn_type1)) {
return optional<result>();
}
expr lhs = m_ctx.mk_tmp_local(binding_name(fn_type1), binding_domain(fn_type1));
expr rhs;
lhss.push_back(lhs);
hyps.push_back(lhs);
simp_lemma_args.push_back(lhs);
switch (kinds[i]) {
case congr_arg_kind::Eq: {
lean_assert(m_ctx.is_def_eq(binding_domain(fn_type1), binding_domain(fn_type2)));
rhs = m_ctx.mk_tmp_local(binding_name(fn_type2), binding_domain(fn_type2));
expr eq_type = m_builder.mk_eq(lhs, rhs);
rhss.push_back(rhs);
expr eq = m_ctx.mk_tmp_local(e_name.append_after(eqs.size()+1), eq_type);
eqs.push_back(some_expr(eq));
hyps.push_back(rhs);
hyps.push_back(eq);
simp_lemma_args.push_back(rhs);
simp_lemma_args.push_back(eq);
break;
}
case congr_arg_kind::Fixed:
rhs = lhs;
rhss.push_back(rhs);
eqs.push_back(none_expr());
break;
case congr_arg_kind::Cast: {
rhs = m_ctx.mk_tmp_local(binding_name(fn_type2), binding_domain(fn_type2));
rhss.push_back(rhs);
eqs.push_back(none_expr());
hyps.push_back(rhs);
break;
}}
fn_type1 = whnf(instantiate(binding_body(fn_type1), lhs));
fn_type2 = whnf(instantiate(binding_body(fn_type2), rhs));
}
expr pr1 = mk_app(simp_lemma->get_proof(), simp_lemma_args);
expr type1 = simp_lemma->get_type();
while (is_pi(type1))
type1 = binding_body(type1);
type1 = instantiate_rev(type1, simp_lemma_args.size(), simp_lemma_args.data());
expr lhs1, rhs1;
lean_verify(is_eq(type1, lhs1, rhs1));
// build proof2
expr rhs2 = mk_app(fn, rhss);
expr eq = m_builder.mk_eq(lhs1, rhs2);
expr congr_type = Pi(hyps, eq);
// build proof that rhs1 = rhs2
unsigned i;
for (i = 0; i < kinds.size(); i++) {
if (kinds[i] == congr_arg_kind::Cast && !pinfos[i].is_prop())
break;
}
if (i == kinds.size()) {
// rhs1 and rhs2 are definitionally equal
expr congr_proof = Fun(hyps, pr1);
return optional<result>(congr_type, congr_proof, to_list(kinds));
}
buffer<expr> rhss1;
get_app_args(rhs1, rhss1);
lean_assert(rhss.size() == rhss1.size());
expr a = mk_app(fn, i, rhss1.data());
expr pr2 = m_builder.mk_eq_refl(a);
for (; i < kinds.size(); i++) {
if (kinds[i] == congr_arg_kind::Cast && !pinfos[i].is_prop()) {
lean_assert(pinfos[i].is_subsingleton());
expr r1 = rhss1[i];
expr r2 = rhss[i];
expr r1_eq_r2 = m_builder.mk_app(get_subsingleton_elim_name(), r1, r2);
pr2 = m_builder.mk_congr(pr2, r1_eq_r2);
} else {
pr2 = m_builder.mk_congr_fun(pr2, rhss[i]);
}
}
expr congr_proof = Fun(hyps, m_builder.mk_eq_trans(pr1, pr2));
return optional<result>(congr_type, congr_proof, to_list(kinds));
} catch (app_builder_exception &) {
return optional<result>();
}
}
public: public:
imp(app_builder & b, fun_info_manager & fm, bool ignore_inst_implicit): imp(app_builder & b, fun_info_manager & fm, bool ignore_inst_implicit):
m_builder(b), m_fmanager(fm), m_ctx(fm.ctx()), m_ignore_inst_implicit(ignore_inst_implicit) {} m_builder(b), m_fmanager(fm), m_ctx(fm.ctx()), m_ignore_inst_implicit(ignore_inst_implicit) {}
@ -202,8 +298,8 @@ public:
} }
optional<result> mk_congr_simp(expr const & fn, unsigned nargs) { optional<result> mk_congr_simp(expr const & fn, unsigned nargs) {
auto r = m_cache.find(key(fn, nargs)); auto r = m_simp_cache.find(key(fn, nargs));
if (r != m_cache.end()) if (r != m_simp_cache.end())
return optional<result>(r->second); return optional<result>(r->second);
fun_info finfo = m_fmanager.get(fn, nargs); fun_info finfo = m_fmanager.get(fn, nargs);
list<unsigned> const & result_deps = finfo.get_dependencies(); list<unsigned> const & result_deps = finfo.get_dependencies();
@ -238,7 +334,7 @@ public:
} }
auto new_r = mk_congr_simp(fn, pinfos, kinds); auto new_r = mk_congr_simp(fn, pinfos, kinds);
if (new_r) { if (new_r) {
m_cache.insert(mk_pair(key(fn, nargs), *new_r)); m_simp_cache.insert(mk_pair(key(fn, nargs), *new_r));
return new_r; return new_r;
} else if (has_cast(kinds)) { } else if (has_cast(kinds)) {
// remove casts and try again // remove casts and try again
@ -248,7 +344,7 @@ public:
} }
auto new_r = mk_congr_simp(fn, pinfos, kinds); auto new_r = mk_congr_simp(fn, pinfos, kinds);
if (new_r) { if (new_r) {
m_cache.insert(mk_pair(key(fn, nargs), *new_r)); m_simp_cache.insert(mk_pair(key(fn, nargs), *new_r));
return new_r; return new_r;
} else { } else {
return new_r; return new_r;
@ -257,6 +353,47 @@ public:
return new_r; return new_r;
} }
} }
optional<result> mk_congr(expr const & fn) {
fun_info finfo = m_fmanager.get(fn);
return mk_congr(fn, finfo.get_arity());
}
optional<result> mk_congr(expr const & fn, unsigned nargs) {
auto r = m_cache.find(key(fn, nargs));
if (r != m_cache.end())
return optional<result>(r->second);
fun_info finfo = m_fmanager.get(fn, nargs);
optional<result> simp_lemma = mk_congr_simp(fn, nargs);
if (!simp_lemma)
return optional<result>();
buffer<congr_arg_kind> kinds;
buffer<param_info> pinfos;
to_buffer(simp_lemma->get_arg_kinds(), kinds);
to_buffer(finfo.get_params_info(), pinfos);
// For congr lemmas we have the following restriction:
// if a Cast arg is subsingleton, it is not a proposition,
// and it is a dependent argument, then we mark it as fixed.
// This restriction doesn't affect the standard library,
// but it simplifies the implementation.
lean_assert(kinds.size() == pinfos.size());
bool has_cast = false;
for (unsigned i = 0; i < kinds.size(); i++) {
if (!pinfos[i].is_prop() && pinfos[i].is_subsingleton() && pinfos[i].is_dep()) {
kinds[i] = congr_arg_kind::Fixed;
}
if (kinds[i] == congr_arg_kind::Cast)
has_cast = true;
}
if (!has_cast) {
m_cache.insert(mk_pair(key(fn, nargs), *simp_lemma));
return simp_lemma; // simp_lemma will be identical to regular congr lemma
}
auto new_r = mk_congr(fn, simp_lemma, pinfos, kinds);
if (new_r)
m_cache.insert(mk_pair(key(fn, nargs), *new_r));
return new_r;
}
}; };
congr_lemma_manager::congr_lemma_manager(app_builder & b, fun_info_manager & fm, bool ignore_inst_implicit): congr_lemma_manager::congr_lemma_manager(app_builder & b, fun_info_manager & fm, bool ignore_inst_implicit):
@ -272,4 +409,10 @@ auto congr_lemma_manager::mk_congr_simp(expr const & fn) -> optional<result> {
auto congr_lemma_manager::mk_congr_simp(expr const & fn, unsigned nargs) -> optional<result> { auto congr_lemma_manager::mk_congr_simp(expr const & fn, unsigned nargs) -> optional<result> {
return m_ptr->mk_congr_simp(fn, nargs); return m_ptr->mk_congr_simp(fn, nargs);
} }
auto congr_lemma_manager::mk_congr(expr const & fn) -> optional<result> {
return m_ptr->mk_congr(fn);
}
auto congr_lemma_manager::mk_congr(expr const & fn, unsigned nargs) -> optional<result> {
return m_ptr->mk_congr(fn, nargs);
}
} }

View file

@ -41,5 +41,8 @@ public:
optional<result> mk_congr_simp(expr const & fn); optional<result> mk_congr_simp(expr const & fn);
optional<result> mk_congr_simp(expr const & fn, unsigned nargs); optional<result> mk_congr_simp(expr const & fn, unsigned nargs);
optional<result> mk_congr(expr const & fn);
optional<result> mk_congr(expr const & fn, unsigned nargs);
}; };
} }