diff --git a/src/library/blast/simplifier/simplifier.cpp b/src/library/blast/simplifier/simplifier.cpp index 87e00357b..ec3f5d75f 100644 --- a/src/library/blast/simplifier/simplifier.cpp +++ b/src/library/blast/simplifier/simplifier.cpp @@ -168,10 +168,13 @@ class simplifier { typedef std::unordered_map simplify_cache; simplify_cache m_cache; - optional cache_lookup(expr const & e); void cache_save(expr const & e, result const & r); + /* Mapping from subsingleton type to representative */ + expr_map m_subsingleton_elem_map; + optional normalize_subsingleton_args(expr const & e); + /* Basic helpers */ bool using_eq() { return m_rel == get_eq_name(); } @@ -370,7 +373,8 @@ expr simplifier::whnf_eta(expr const & e) { result simplifier::simplify(expr const & e, simp_lemmas const & srss) { flet set_srss(m_srss, srss); - freset reset(m_cache); + freset reset1(m_cache); + freset> reset2(m_subsingleton_elem_map); return simplify(e, true); } @@ -458,7 +462,22 @@ result simplifier::simplify_lambda(expr const & e) { t = instantiate(binding_body(t), l); } - result r = simplify(t, false); + result r = simplify(t, false); + expr new_t = r.get_new(); + /* check if subsingleton, and normalize */ + expr new_t_type = m_tmp_tctx->infer(new_t); + if (m_tmp_tctx->mk_subsingleton_instance(new_t_type)) { + auto it = m_subsingleton_elem_map.find(new_t_type); + if (it != m_subsingleton_elem_map.end()) { + if (it->second != new_t) { + expr proof = get_app_builder().mk_app(get_subsingleton_elim_name(), new_t, it->second); + r = join(r, result(it->second, proof)); + } + } else { + m_subsingleton_elem_map.insert(mk_pair(new_t_type, new_t)); + } + } + for (int i = ls.size() - 1; i >= 0; --i) r = funext(r, ls[i]); return r; } @@ -797,6 +816,61 @@ bool simplifier::instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned return !failed; } +/* Given a function application \c e, replace arguments that are subsingletons with a + representative */ +optional simplifier::normalize_subsingleton_args(expr const & e) { + buffer args; + get_app_args(e, args); + auto congr_lemma = mk_specialized_congr_lemma(e); + if (!congr_lemma) return optional(); + expr proof = congr_lemma->get_proof(); + expr type = congr_lemma->get_type(); + unsigned i = 0; + bool normalized = false; + for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) { + expr rfl; + switch (ckind) { + case congr_arg_kind::Fixed: + proof = mk_app(proof, args[i]); + type = instantiate(binding_body(type), args[i]); + break; + case congr_arg_kind::FixedNoParam: + break; + case congr_arg_kind::Eq: + proof = mk_app(proof, args[i]); + type = instantiate(binding_body(type), args[i]); + rfl = get_app_builder().mk_eq_refl(args[i]); + proof = mk_app(proof, args[i], rfl); + type = instantiate(binding_body(type), args[i]); + type = instantiate(binding_body(type), rfl); + break; + case congr_arg_kind::Cast: + proof = mk_app(proof, args[i]); + type = instantiate(binding_body(type), args[i]); + expr const & arg_type = binding_domain(type); + expr new_arg; + auto it = m_subsingleton_elem_map.find(arg_type); + if (it != m_subsingleton_elem_map.end()) { + normalized = (it->second != args[i]); + new_arg = it->second; + } else { + new_arg = args[i]; + m_subsingleton_elem_map.insert(mk_pair(arg_type, args[i])); + } + proof = mk_app(proof, new_arg); + type = instantiate(binding_body(type), new_arg); + break; + } + i++; + }); + if (!normalized) return optional(); + lean_assert(is_eq(type)); + buffer type_args; + get_app_args(type, type_args); + expr e_new = type_args[2]; + return optional(result(e_new, proof)); +} + template optional simplifier::synth_congr(expr const & e, F && simp) { static_assert(std::is_same::type, result>::value, @@ -843,8 +917,15 @@ optional simplifier::synth_congr(expr const & e, F && simp) { buffer type_args; get_app_args(type, type_args); expr e_new = remove_unnecessary_casts(type_args[2]); - if (has_proof) return optional(result(e_new, proof)); - else return optional(result(e_new)); + result r; + if (has_proof) r = result(e_new, proof); + else r = result(e_new); + + if (has_cast) { + if (auto r_norm = normalize_subsingleton_args(e_new)) + r = join(r, *r_norm); + } + return optional(r); } expr simplifier::remove_unnecessary_casts(expr const & e) { diff --git a/tests/lean/run/blast_simp_subsingleton2.lean b/tests/lean/run/blast_simp_subsingleton2.lean new file mode 100644 index 000000000..f8fd19946 --- /dev/null +++ b/tests/lean/run/blast_simp_subsingleton2.lean @@ -0,0 +1,16 @@ +import data.unit +open nat unit + +constant r {A B : Type} : A → B → A + +example (a b c d : unit) : r a b = r c d := +by simp + +example (a b : unit) : a = b := +by simp + +example (a b : unit) : (λ x : nat, a) = (λ y : nat, b) := +by simp + +example (a b : unit) : (λ x : nat, r a b) = (λ y : nat, r b b) := +by simp