diff --git a/src/library/tactic/inversion_tactic.cpp b/src/library/tactic/inversion_tactic.cpp index 049ea8262..d45a4cec7 100644 --- a/src/library/tactic/inversion_tactic.cpp +++ b/src/library/tactic/inversion_tactic.cpp @@ -20,12 +20,13 @@ Author: Leonardo de Moura namespace lean { namespace inversion { -result::result(list const & gs, list const & num_args, list const & imps, - name_generator const & ngen, substitution const & subst): - m_goals(gs), m_num_args(num_args), m_implementation_lists(imps), - m_ngen(ngen), m_subst(subst) { - lean_assert_eq(length(m_goals), length(m_num_args)); +result::result(list const & gs, list> const & args, list const & imps, + list const & rs, name_generator const & ngen, substitution const & subst): + m_goals(gs), m_args(args), m_implementation_lists(imps), + m_renames(rs), m_ngen(ngen), m_subst(subst) { + lean_assert_eq(length(m_goals), length(m_args)); lean_assert_eq(length(m_goals), length(m_implementation_lists)); + lean_assert_eq(length(m_goals), length(m_renames)); } } @@ -75,7 +76,6 @@ optional apply_eq_rec_eq(type_checker & tc, io_state const & ios, listupdate_exprs([&](expr const & e) { return abstract_locals(e, nlocals, locals); }); @@ -308,11 +308,12 @@ class inversion_tac { } } - std::pair, list> intros_minors_args(list gs) { + std::pair, list>> intros_minors_args(list gs) { buffer minors_nargs; get_minors_nargs(minors_nargs); lean_assert(length(gs) == minors_nargs.size()); buffer new_gs; + buffer> new_args; for (unsigned i = 0; i < minors_nargs.size(); i++) { goal const & g = head(gs); unsigned nargs = minors_nargs[i]; @@ -321,6 +322,7 @@ class inversion_tac { buffer new_hyps; new_hyps.append(hyps); expr g_type = g.get_type(); + buffer curr_new_args; for (unsigned i = 0; i < nargs; i++) { expr type = binding_domain(g_type); name new_h_name; @@ -331,9 +333,11 @@ class inversion_tac { new_h_name = binding_name(g_type); } expr new_h = mk_local(m_ngen.next(), get_unused_name(new_h_name, new_hyps), type, binder_info()); + curr_new_args.push_back(mlocal_name(new_h)); new_hyps.push_back(new_h); g_type = instantiate(binding_body(g_type), new_h); } + new_args.push_back(to_list(curr_new_args)); g_type = head_beta_reduce(g_type); expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(new_hyps, g_type)), new_hyps); goal new_g(new_meta, g_type); @@ -342,7 +346,7 @@ class inversion_tac { assign(g.get_name(), val); gs = tail(gs); } - return mk_pair(to_list(new_gs), to_list(minors_nargs)); + return mk_pair(to_list(new_gs), to_list(new_args)); } struct inversion_exception : public exception { @@ -514,11 +518,21 @@ class inversion_tac { } } - typedef optional> unify_result; + rename_map m_renames; + implementation_list m_imps; - unify_result unify_eqs(goal g, implementation_list imps, unsigned neqs) { + // update m_renames with old_hyps --> new_hyps. + void store_renames(buffer const & old_hyps, buffer const & new_hyps) { + lean_assert(old_hyps.size() == new_hyps.size()); + for (unsigned i = 0; i < old_hyps.size(); i++) { + m_renames.insert(mlocal_name(old_hyps[i]), mlocal_name(new_hyps[i])); + } + } + + // Remark: it also updates m_renames and m_imps + optional unify_eqs(goal g, unsigned neqs) { if (neqs == 0) - return unify_result(g, imps); // done + return optional(g); // done g = intro_next_eq(g); buffer hyps; g.get_hyps(hyps); @@ -538,7 +552,7 @@ class inversion_tac { goal new_g(new_meta, new_type); expr val = g.abstract(new_meta); assign(g.get_name(), val); - return unify_eqs(new_g, imps, neqs-1); + return unify_eqs(new_g, neqs-1); } buffer lhs_args, rhs_args; expr const & lhs_fn = get_app_args(lhs, lhs_args); @@ -569,12 +583,12 @@ class inversion_tac { assign(g.get_name(), val); unsigned A_nparams = *inductive::get_num_params(m_env, const_name(A_fn)); lean_assert(lhs_args.size() >= A_nparams); - return unify_eqs(new_g, imps, neqs - 1 + lhs_args.size() - A_nparams); + return unify_eqs(new_g, neqs - 1 + lhs_args.size() - A_nparams); } else { // conflict transition, eq is of the form c_1 ... = c_2 ..., where c_1 and c_2 are different constructors/intro rules. expr val = g.abstract(lift_down(no_confusion)); assign(g.get_name(), val); - return unify_result(); // goal has been solved + return optional(); // goal has been solved } } if (is_local(rhs)) { @@ -602,7 +616,7 @@ class inversion_tac { buffer non_deps, deps; split_deps(hyps, rhs, non_deps, deps); expr deps_g_type = Pi(deps, g_type); - abstract_locals(imps, deps); + abstract_locals(m_imps, deps); level eq_rec_lvl1 = sort_level(m_tc.ensure_type(deps_g_type).first); level eq_rec_lvl2 = sort_level(m_tc.ensure_type(A).first); expr tformer; @@ -615,19 +629,23 @@ class inversion_tac { buffer new_hyps; new_hyps.append(non_deps); expr new_type = instantiate(abstract_local(deps_g_type, rhs), lhs); - abstract_local(imps, rhs); - instantiate(imps, lhs); + abstract_local(m_imps, rhs); + instantiate(m_imps, lhs); if (!m_proof_irrel) { new_type = abstract_local(new_type, eq); new_type = instantiate(new_type, mk_refl(m_tc, lhs)); } + buffer new_deps; for (unsigned i = 0; i < deps.size(); i++) { expr new_hyp = mk_local(m_ngen.next(), binding_name(new_type), binding_domain(new_type), binding_info(new_type)); new_hyps.push_back(new_hyp); + new_deps.push_back(new_hyp); new_type = instantiate(binding_body(new_type), new_hyp); - instantiate(imps, new_hyp); + instantiate(m_imps, new_hyp); } + lean_assert(deps.size() == new_deps.size()); + store_renames(deps, new_deps); expr new_mvar = mk_metavar(m_ngen.next(), Pi(new_hyps, new_type)); expr new_meta = mk_app(new_mvar, new_hyps); goal new_g(new_meta, new_type); @@ -635,7 +653,7 @@ class inversion_tac { eq_rec = mk_app(eq_rec, eq_rec_minor, rhs, eq); expr val = g.abstract(mk_app(eq_rec, deps)); assign(g.get_name(), val); - return unify_eqs(new_g, imps, neqs-1); + return unify_eqs(new_g, neqs-1); } else if (is_local(lhs)) { // flip equation and reduce to previous case if (m_proof_irrel) @@ -650,28 +668,33 @@ class inversion_tac { symm_pr = mk_app(symm_pr, A, lhs, rhs, eq); expr val = g.abstract(mk_app(new_meta, symm_pr)); assign(g.get_name(), val); - return unify_eqs(new_g, imps, neqs); + return unify_eqs(new_g, neqs); } throw inversion_exception("unification failed"); } - auto unify_eqs(list const & gs, list nargs, list imps) -> - std::tuple, list, list> { + auto unify_eqs(list const & gs, list> args, list imps) -> + std::tuple, list>, list, list> { lean_assert(length(gs) == length(imps)); unsigned neqs = m_nindices + (m_dep_elim ? 1 : 0); buffer new_goals; - buffer new_nargs; + buffer> new_args; buffer new_imps; + buffer new_renames; for (goal const & g : gs) { - if (auto g_imp_pair = unify_eqs(g, head(imps), neqs)) { - new_goals.push_back(g_imp_pair->first); - new_nargs.push_back(head(nargs)); - new_imps.push_back(g_imp_pair->second); + flet set1(m_renames, rename_map()); + flet set2(m_imps, head(imps)); + if (optional new_g = unify_eqs(g, neqs)) { + new_goals.push_back(*new_g); + list new_as = map(head(args), [&](name const & n) { return m_renames.find(n); }); + new_args.push_back(new_as); + new_imps.push_back(m_imps); + new_renames.push_back(m_renames); } imps = tail(imps); - nargs = tail(nargs); + args = tail(args); } - return std::make_tuple(to_list(new_goals), to_list(new_nargs), to_list(new_imps)); + return std::make_tuple(to_list(new_goals), to_list(new_args), to_list(new_imps), to_list(new_renames)); } public: @@ -685,15 +708,10 @@ public: inversion_tac(environment const & env, io_state const & ios, type_checker & tc): inversion_tac(env, ios, tc.mk_ngen(), tc, substitution(), list()) {} - typedef inversion::result result; - optional execute(goal const & g, name const & n, implementation_list const & imps) { + optional execute(goal const & g, expr const & h, implementation_list const & imps) { try { - auto p = g.find_hyp(n); - if (!p) - return optional(); - expr const & h = p->first; expr h_type = m_tc.whnf(mlocal_type(h)).first; if (!is_inversion_applicable(h_type)) return optional(); @@ -701,22 +719,31 @@ public: auto gs_imps_pair = apply_cases_on(g1, imps); list gs2 = gs_imps_pair.first; list new_imps = gs_imps_pair.second; - auto gs_nargs_pair = intros_minors_args(gs2); - list gs3 = gs_nargs_pair.first; - list nargs = gs_nargs_pair.second; + auto gs_args_pair = intros_minors_args(gs2); + list gs3 = gs_args_pair.first; + list> args = gs_args_pair.second; list gs4; - std::tie(gs4, nargs, new_imps) = unify_eqs(gs3, nargs, new_imps); - return optional(result(gs4, nargs, new_imps, m_ngen, m_subst)); + list rs; + std::tie(gs4, args, new_imps, rs) = unify_eqs(gs3, args, new_imps); + return optional(result(gs4, args, new_imps, rs, m_ngen, m_subst)); } catch (inversion_exception & ex) { return optional(); } } + + optional execute(goal const & g, name const & n, implementation_list const & imps) { + auto p = g.find_hyp(n); + if (!p) + return optional(); + expr const & h = p->first; + return execute(g, h, imps); + } }; namespace inversion { optional apply(environment const & env, io_state const & ios, type_checker & tc, - goal const & g, name const & n, implementation_list const & imps) { - return inversion_tac(env, ios, tc).execute(g, n, imps); + goal const & g, expr const & h, implementation_list const & imps) { + return inversion_tac(env, ios, tc).execute(g, h, imps); } } diff --git a/src/library/tactic/inversion_tactic.h b/src/library/tactic/inversion_tactic.h index 4a4f30162..7c2cf9331 100644 --- a/src/library/tactic/inversion_tactic.h +++ b/src/library/tactic/inversion_tactic.h @@ -7,6 +7,7 @@ Author: Leonardo de Moura #pragma once #include #include +#include "util/name_map.h" #include "library/tactic/tactic.h" namespace lean { @@ -43,19 +44,22 @@ typedef list implementation_list; struct result { list m_goals; - list m_num_args; + list> m_args; // arguments of the constructor/intro rule list m_implementation_lists; - // invariant: length(m_goals) == length(m_num_args); + list m_renames; + // invariant: length(m_goals) == length(m_args); // invariant: length(m_goals) == length(m_implementation_lists); + // invariant: length(m_goals) == length(m_renames); name_generator m_ngen; substitution m_subst; - result(list const & gs, list const & num_args, list const & imps, - name_generator const & ngen, substitution const & subst); + result(list const & gs, list> const & args, list const & imps, + list const & rs, name_generator const & ngen, substitution const & subst); }; optional apply(environment const & env, io_state const & ios, type_checker & tc, - goal const & g, name const & n, implementation_list const & imps); + goal const & g, expr const & h, implementation_list const & imps); } + tactic inversion_tactic(name const & n, list const & ids = list()); void initialize_inversion_tactic(); void finalize_inversion_tactic();