diff --git a/src/library/blast/forward/ematch.cpp b/src/library/blast/forward/ematch.cpp index 4bc451858..daf111fc3 100644 --- a/src/library/blast/forward/ematch.cpp +++ b/src/library/blast/forward/ematch.cpp @@ -297,18 +297,36 @@ struct ematch_fn { if (!cg_lemma) return false; buffer t_args; - get_app_args(t, t_args); + expr const & fn = get_app_args(t, t_args); if (p_args.size() != t_args.size()) return false; - auto const * r_names = &cg_lemma->m_rel_names; - for (unsigned i = 0; i < p_args.size(); i++) { - lean_assert(*r_names); - if (auto Rc = head(*r_names)) { - s = cons(entry(*Rc, Match, p_args[i], t_args[i]), s); - } else { - s = cons(entry(get_eq_name(), DefEqOnly, p_args[i], t_args[i]), s); + if (cg_lemma->m_hcongr_lemma) { + /* Lemma was created using mk_hcongr_lemma */ + lean_assert(is_standard(env())); + fun_info finfo = get_fun_info(fn, t_args.size()); + list const * pinfos = &finfo.get_params_info(); + lean_assert(length(*pinfos) == t_args.size()); + for (unsigned i = 0; i < t_args.size(); i++) { + param_info const & pinfo = head(*pinfos); + if (!pinfo.is_inst_implicit() && !pinfo.is_implicit() && !pinfo.is_subsingleton()) { + /* We only match explicit arguments that are *not* subsingletons */ + s = cons(entry(get_eq_name(), Match, p_args[i], t_args[i]), s); + } else { + s = cons(entry(get_eq_name(), DefEqOnly, p_args[i], t_args[i]), s); + } + pinfos = &tail(*pinfos); + } + } else { + auto const * r_names = &cg_lemma->m_rel_names; + for (unsigned i = 0; i < p_args.size(); i++) { + lean_assert(*r_names); + if (auto Rc = head(*r_names)) { + s = cons(entry(*Rc, Match, p_args[i], t_args[i]), s); + } else { + s = cons(entry(get_eq_name(), DefEqOnly, p_args[i], t_args[i]), s); + } + r_names = &tail(*r_names); } - r_names = &tail(*r_names); } return true; } diff --git a/tests/lean/run/blast_ematch_heq1.lean b/tests/lean/run/blast_ematch_heq1.lean new file mode 100644 index 000000000..67df52e1b --- /dev/null +++ b/tests/lean/run/blast_ematch_heq1.lean @@ -0,0 +1,15 @@ +import data.nat +open algebra nat + +section +open nat +set_option blast.strategy "ematch" +set_option blast.cc.heq true + +attribute add.comm [forward] +attribute add.assoc [forward] + +example (a b c : nat) : a + b + b + c = c + b + a + b := +by blast + +end diff --git a/tests/lean/run/blast_ematch_heq2.lean b/tests/lean/run/blast_ematch_heq2.lean new file mode 100644 index 000000000..57b6d5dea --- /dev/null +++ b/tests/lean/run/blast_ematch_heq2.lean @@ -0,0 +1,10 @@ +import algebra.group + +variable {A : Type} +variable [s : group A] +include s + +set_option blast.cc.heq true + +example (a : A) : a * 1⁻¹ = a := +by inst_simp