feat(library/blast/simplifier/simplifier): subsingleton normalization for application arguments and lambdas

This commit is contained in:
Leonardo de Moura 2016-01-06 17:30:08 -08:00
parent e7bcb89314
commit c9930d0a29
2 changed files with 102 additions and 5 deletions

View file

@ -168,10 +168,13 @@ class simplifier {
typedef std::unordered_map<key, result, key_hash_fn, key_eq_fn> simplify_cache;
simplify_cache m_cache;
optional<result> cache_lookup(expr const & e);
void cache_save(expr const & e, result const & r);
/* Mapping from subsingleton type to representative */
expr_map<expr> m_subsingleton_elem_map;
optional<result> normalize_subsingleton_args(expr const & e);
/* Basic helpers */
bool using_eq() { return m_rel == get_eq_name(); }
@ -370,7 +373,8 @@ expr simplifier::whnf_eta(expr const & e) {
result simplifier::simplify(expr const & e, simp_lemmas const & srss) {
flet<simp_lemmas> set_srss(m_srss, srss);
freset<simplify_cache> reset(m_cache);
freset<simplify_cache> reset1(m_cache);
freset<expr_map<expr>> reset2(m_subsingleton_elem_map);
return simplify(e, true);
}
@ -458,7 +462,22 @@ result simplifier::simplify_lambda(expr const & e) {
t = instantiate(binding_body(t), l);
}
result r = simplify(t, false);
result r = simplify(t, false);
expr new_t = r.get_new();
/* check if subsingleton, and normalize */
expr new_t_type = m_tmp_tctx->infer(new_t);
if (m_tmp_tctx->mk_subsingleton_instance(new_t_type)) {
auto it = m_subsingleton_elem_map.find(new_t_type);
if (it != m_subsingleton_elem_map.end()) {
if (it->second != new_t) {
expr proof = get_app_builder().mk_app(get_subsingleton_elim_name(), new_t, it->second);
r = join(r, result(it->second, proof));
}
} else {
m_subsingleton_elem_map.insert(mk_pair(new_t_type, new_t));
}
}
for (int i = ls.size() - 1; i >= 0; --i) r = funext(r, ls[i]);
return r;
}
@ -797,6 +816,61 @@ bool simplifier::instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned
return !failed;
}
/* Given a function application \c e, replace arguments that are subsingletons with a
representative */
optional<result> simplifier::normalize_subsingleton_args(expr const & e) {
buffer<expr> args;
get_app_args(e, args);
auto congr_lemma = mk_specialized_congr_lemma(e);
if (!congr_lemma) return optional<result>();
expr proof = congr_lemma->get_proof();
expr type = congr_lemma->get_type();
unsigned i = 0;
bool normalized = false;
for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) {
expr rfl;
switch (ckind) {
case congr_arg_kind::Fixed:
proof = mk_app(proof, args[i]);
type = instantiate(binding_body(type), args[i]);
break;
case congr_arg_kind::FixedNoParam:
break;
case congr_arg_kind::Eq:
proof = mk_app(proof, args[i]);
type = instantiate(binding_body(type), args[i]);
rfl = get_app_builder().mk_eq_refl(args[i]);
proof = mk_app(proof, args[i], rfl);
type = instantiate(binding_body(type), args[i]);
type = instantiate(binding_body(type), rfl);
break;
case congr_arg_kind::Cast:
proof = mk_app(proof, args[i]);
type = instantiate(binding_body(type), args[i]);
expr const & arg_type = binding_domain(type);
expr new_arg;
auto it = m_subsingleton_elem_map.find(arg_type);
if (it != m_subsingleton_elem_map.end()) {
normalized = (it->second != args[i]);
new_arg = it->second;
} else {
new_arg = args[i];
m_subsingleton_elem_map.insert(mk_pair(arg_type, args[i]));
}
proof = mk_app(proof, new_arg);
type = instantiate(binding_body(type), new_arg);
break;
}
i++;
});
if (!normalized) return optional<result>();
lean_assert(is_eq(type));
buffer<expr> type_args;
get_app_args(type, type_args);
expr e_new = type_args[2];
return optional<result>(result(e_new, proof));
}
template<typename F>
optional<result> simplifier::synth_congr(expr const & e, F && simp) {
static_assert(std::is_same<typename std::result_of<F(expr const & e)>::type, result>::value,
@ -843,8 +917,15 @@ optional<result> simplifier::synth_congr(expr const & e, F && simp) {
buffer<expr> type_args;
get_app_args(type, type_args);
expr e_new = remove_unnecessary_casts(type_args[2]);
if (has_proof) return optional<result>(result(e_new, proof));
else return optional<result>(result(e_new));
result r;
if (has_proof) r = result(e_new, proof);
else r = result(e_new);
if (has_cast) {
if (auto r_norm = normalize_subsingleton_args(e_new))
r = join(r, *r_norm);
}
return optional<result>(r);
}
expr simplifier::remove_unnecessary_casts(expr const & e) {

View file

@ -0,0 +1,16 @@
import data.unit
open nat unit
constant r {A B : Type} : A → B → A
example (a b c d : unit) : r a b = r c d :=
by simp
example (a b : unit) : a = b :=
by simp
example (a b : unit) : (λ x : nat, a) = (λ y : nat, b) :=
by simp
example (a b : unit) : (λ x : nat, r a b) = (λ y : nat, r b b) :=
by simp