From 1f13bfa4f7a5401cd5c32fdf2b2f8f7628e86edc Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 30 Dec 2014 21:22:50 -0800 Subject: [PATCH] feat(library/tactic/inversion_tactic): add inversion::apply procedure The new procedure is essentially a "customized" version of the inversion (aka cases) tactic for the equations package --- src/library/tactic/inversion_tactic.cpp | 229 ++++++++++++++++-------- src/library/tactic/inversion_tactic.h | 48 +++++ 2 files changed, 202 insertions(+), 75 deletions(-) diff --git a/src/library/tactic/inversion_tactic.cpp b/src/library/tactic/inversion_tactic.cpp index 24e0e5cef..049ea8262 100644 --- a/src/library/tactic/inversion_tactic.cpp +++ b/src/library/tactic/inversion_tactic.cpp @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include #include "util/sstream.h" #include "kernel/abstract.h" #include "kernel/instantiate.h" @@ -15,8 +16,19 @@ Author: Leonardo de Moura #include "library/tactic/tactic.h" #include "library/tactic/expr_to_tactic.h" #include "library/tactic/class_instance_synth.h" +#include "library/tactic/inversion_tactic.h" namespace lean { +namespace inversion { +result::result(list const & gs, list const & num_args, list const & imps, + name_generator const & ngen, substitution const & subst): + m_goals(gs), m_num_args(num_args), m_implementation_lists(imps), + m_ngen(ngen), m_subst(subst) { + lean_assert_eq(length(m_goals), length(m_num_args)); + lean_assert_eq(length(m_goals), length(m_implementation_lists)); +} +} + /** \brief Given eq_rec of the form @eq.rec.{l₂ l₁} A a (λ (a' : A) (h : a = a'), B a') b a p, apply the eq_rec_eq definition to produce the equality @@ -61,14 +73,36 @@ optional apply_eq_rec_eq(type_checker & tc, io_state const & ios, listupdate_exprs([&](expr const & e) { return abstract_locals(e, nlocals, locals); }); + } +} + +static void instantiate(implementation_list const & imps, expr const & local) { + for (implementation_ptr const & imp : imps) { + imp->update_exprs([&](expr const & e) { return instantiate(e, local); }); + } +} + +static void abstract_locals(implementation_list const & imps, buffer const & locals) { + abstract_locals(imps, locals.size(), locals.data()); +} + +static void abstract_local(implementation_list const & imps, expr const & local) { + abstract_locals(imps, 1, &local); +} + class inversion_tac { environment const & m_env; io_state const & m_ios; - proof_state const & m_ps; + type_checker & m_tc; list m_ids; name_generator m_ngen; substitution m_subst; - std::unique_ptr m_tc; bool m_dep_elim; bool m_proof_irrel; @@ -109,11 +143,11 @@ class inversion_tac { } pair mk_eq(expr const & lhs, expr const & rhs) { - expr lhs_type = m_tc->infer(lhs).first; - expr rhs_type = m_tc->infer(rhs).first; - level l = sort_level(m_tc->ensure_type(lhs_type).first); + expr lhs_type = m_tc.infer(lhs).first; + expr rhs_type = m_tc.infer(rhs).first; + level l = sort_level(m_tc.ensure_type(lhs_type).first); constraint_seq cs; - if (m_tc->is_def_eq(lhs_type, rhs_type, justification(), cs) && !cs) { + if (m_tc.is_def_eq(lhs_type, rhs_type, justification(), cs) && !cs) { return mk_pair(mk_app(mk_constant("eq", to_list(l)), lhs_type, lhs, rhs), mk_app(mk_constant({"eq", "refl"}, to_list(l)), rhs_type, rhs)); } else { @@ -135,7 +169,7 @@ class inversion_tac { buffer I_args; expr const & I = get_app_args(h_type, I_args); expr h_new_type = mk_app(I, I_args.size() - m_nindices, I_args.data()); - expr d = m_tc->whnf(m_tc->infer(h_new_type).first).first; + expr d = m_tc.whnf(m_tc.infer(h_new_type).first).first; name t_prefix("t"); unsigned nidx = 1; if (m_proof_irrel) { @@ -182,7 +216,7 @@ class inversion_tac { expr t = mk_local(m_ngen.next(), g.get_unused_name(t_prefix, nidx), t_type, binder_info()); h_new_type = mk_app(h_new_type, t); ss.push_back(I_args[i]); - refls.push_back(mk_refl(*m_tc, I_args[i])); + refls.push_back(mk_refl(m_tc, I_args[i])); hyps.push_back(t); ts.push_back(t); d = instantiate(binding_body(d), t); @@ -190,10 +224,10 @@ class inversion_tac { expr h_new = mk_local(m_ngen.next(), h_new_name, h_new_type, local_info(h)); ts.push_back(h_new); ss.push_back(h); - refls.push_back(mk_refl(*m_tc, h)); + refls.push_back(mk_refl(m_tc, h)); hyps.push_back(h_new); buffer eqs; - mk_telescopic_eq(*m_tc, ss, ts, eqs); + mk_telescopic_eq(m_tc, ss, ts, eqs); ts.pop_back(); expr new_type = Pi(eqs, g.get_type()); expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(hyps, new_type)), hyps); @@ -205,20 +239,21 @@ class inversion_tac { } } - list apply_cases_on(goal const & g) { + std::pair, list> apply_cases_on(goal const & g, implementation_list const & imps) { buffer hyps; g.get_hyps(hyps); expr const & h = hyps.back(); expr const & h_type = mlocal_type(h); buffer I_args; expr const & I = get_app_args(h_type, I_args); + name const & I_name = const_name(I); expr g_type = g.get_type(); expr cases_on; if (length(m_cases_on_decl.get_univ_params()) != length(m_I_decl.get_univ_params())) { - level g_lvl = sort_level(m_tc->ensure_type(g_type).first); - cases_on = mk_constant({const_name(I), "cases_on"}, cons(g_lvl, const_levels(I))); + level g_lvl = sort_level(m_tc.ensure_type(g_type).first); + cases_on = mk_constant({I_name, "cases_on"}, cons(g_lvl, const_levels(I))); } else { - cases_on = mk_constant({const_name(I), "cases_on"}, const_levels(I)); + cases_on = mk_constant({I_name, "cases_on"}, const_levels(I)); } // add params cases_on = mk_app(cases_on, m_nparams, I_args.data()); @@ -232,22 +267,28 @@ class inversion_tac { cases_on = mk_app(cases_on, m_nindices, I_args.end() - m_nindices); // add h cases_on = mk_app(cases_on, h); + buffer intro_names; + get_intro_rule_names(m_env, I_name, intro_names); + lean_assert(m_nminors == intro_names.size()); buffer new_hyps; new_hyps.append(hyps.size() - m_nindices - 1, hyps.data()); // add a subgoal for each minor premise of cases_on - expr cases_on_type = m_tc->whnf(m_tc->infer(cases_on).first).first; + expr cases_on_type = m_tc.whnf(m_tc.infer(cases_on).first).first; buffer new_goals; + buffer new_imps; for (unsigned i = 0; i < m_nminors; i++) { expr new_type = binding_domain(cases_on_type); expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(new_hyps, new_type)), new_hyps); goal new_g(new_meta, new_type); new_goals.push_back(new_g); cases_on = mk_app(cases_on, new_meta); - cases_on_type = m_tc->whnf(binding_body(cases_on_type)).first; // the minor premises do not depend on each other + cases_on_type = m_tc.whnf(binding_body(cases_on_type)).first; // the minor premises do not depend on each other + name const & intro_name = intro_names[i]; + new_imps.push_back(filter(imps, [&](implementation_ptr const & imp) { return imp->get_constructor_name() == intro_name; })); } expr val = g.abstract(cases_on); assign(g.get_name(), val); - return to_list(new_goals.begin(), new_goals.end()); + return mk_pair(to_list(new_goals), to_list(new_imps)); } // Store in \c r the number of arguments for each cases_on minor. @@ -267,7 +308,7 @@ class inversion_tac { } } - list intros_minors_args(list gs) { + std::pair, list> intros_minors_args(list gs) { buffer minors_nargs; get_minors_nargs(minors_nargs); lean_assert(length(gs) == minors_nargs.size()); @@ -301,7 +342,7 @@ class inversion_tac { assign(g.get_name(), val); gs = tail(gs); } - return to_list(new_gs.begin(), new_gs.end()); + return mk_pair(to_list(new_gs), to_list(minors_nargs)); } struct inversion_exception : public exception { @@ -338,19 +379,19 @@ class inversion_tac { lean_assert(is_eq_rec(lhs)); // lhs is of the form (eq.rec A s C a s p) // aux_eq is a term of type ((eq.rec A s C a s p) = a) - auto aux_eq = apply_eq_rec_eq(*m_tc, m_ios, to_list(hyps), lhs); + auto aux_eq = apply_eq_rec_eq(m_tc, m_ios, to_list(hyps), lhs); if (!aux_eq) throw_unification_eq_rec_failure(); buffer lhs_args; get_app_args(lhs, lhs_args); expr const & reduced_lhs = lhs_args[3]; - expr new_eq = ::lean::mk_eq(*m_tc, reduced_lhs, rhs); + expr new_eq = ::lean::mk_eq(m_tc, reduced_lhs, rhs); expr new_type = update_binding(type, new_eq, binding_body(type)); expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(hyps, new_type)), hyps); goal new_g(new_meta, new_type); // create assignment for g - expr A = m_tc->infer(lhs).first; - level lvl = sort_level(m_tc->ensure_type(A).first); + expr A = m_tc.infer(lhs).first; + level lvl = sort_level(m_tc.ensure_type(A).first); // old_eq : eq.rec A s C a s p = b expr old_eq = mk_local(m_ngen.next(), binding_name(type), eq, binder_info()); // aux_eq : a = eq.rec A s C a s p @@ -373,7 +414,7 @@ class inversion_tac { buffer args; expr const & heq_fn = get_app_args(eq, args); constraint_seq cs; - if (m_tc->is_def_eq(args[0], args[2], justification(), cs) && !cs) { + if (m_tc.is_def_eq(args[0], args[2], justification(), cs) && !cs) { buffer hyps; g.get_hyps(hyps); expr new_eq = mk_app(mk_constant("eq", const_levels(heq_fn)), args[0], args[1], args[3]); @@ -425,8 +466,8 @@ class inversion_tac { if (const_name(eq_fn) == "eq") { expr const & lhs = app_arg(app_fn(eq)); expr const & rhs = app_arg(eq); - expr new_lhs = m_tc->whnf(lhs).first; - expr new_rhs = m_tc->whnf(rhs).first; + expr new_lhs = m_tc.whnf(lhs).first; + expr new_rhs = m_tc.whnf(rhs).first; if (lhs != new_lhs || rhs != new_rhs) { eq = mk_app(app_fn(app_fn(eq)), new_lhs, new_rhs); type = update_binding(type, eq, binding_body(type)); @@ -461,7 +502,7 @@ class inversion_tac { // We must apply lift.down to eliminate the auxiliary lift. expr lift_down(expr const & v) { if (!m_proof_irrel) { - expr v_type = m_tc->whnf(m_tc->infer(v).first).first; + expr v_type = m_tc.whnf(m_tc.infer(v).first).first; if (!is_app(v_type)) throw_unification_eq_rec_failure(); expr const & lift = app_fn(v_type); @@ -473,9 +514,11 @@ class inversion_tac { } } - optional unify_eqs(goal g, unsigned neqs) { + typedef optional> unify_result; + + unify_result unify_eqs(goal g, implementation_list imps, unsigned neqs) { if (neqs == 0) - return optional(g); // done + return unify_result(g, imps); // done g = intro_next_eq(g); buffer hyps; g.get_hyps(hyps); @@ -483,11 +526,11 @@ class inversion_tac { expr eq = hyps.back(); buffer eq_args; get_app_args(mlocal_type(eq), eq_args); - expr const & A = m_tc->whnf(eq_args[0]).first; - expr lhs = m_tc->whnf(eq_args[1]).first; - expr rhs = m_tc->whnf(eq_args[2]).first; + expr const & A = m_tc.whnf(eq_args[0]).first; + expr lhs = m_tc.whnf(eq_args[1]).first; + expr rhs = m_tc.whnf(eq_args[2]).first; constraint_seq cs; - if (m_proof_irrel && m_tc->is_def_eq(lhs, rhs, justification(), cs) && !cs) { + if (m_proof_irrel && m_tc.is_def_eq(lhs, rhs, justification(), cs) && !cs) { // deletion transition: t == t hyps.pop_back(); // remove t == t equality expr new_type = g.get_type(); @@ -495,13 +538,13 @@ class inversion_tac { goal new_g(new_meta, new_type); expr val = g.abstract(new_meta); assign(g.get_name(), val); - return unify_eqs(new_g, neqs-1); + return unify_eqs(new_g, imps, neqs-1); } buffer lhs_args, rhs_args; expr const & lhs_fn = get_app_args(lhs, lhs_args); expr const & rhs_fn = get_app_args(rhs, rhs_args); expr const & g_type = g.get_type(); - level const & g_lvl = sort_level(m_tc->ensure_type(g_type).first); + level const & g_lvl = sort_level(m_tc.ensure_type(g_type).first); if (is_constant(lhs_fn) && is_constant(rhs_fn) && inductive::is_intro_rule(m_env, const_name(lhs_fn)) && @@ -516,7 +559,7 @@ class inversion_tac { expr no_confusion = mk_app(mk_app(mk_constant(no_confusion_name, cons(g_lvl, const_levels(A_fn))), A_args), g_type, lhs, rhs, eq); if (const_name(lhs_fn) == const_name(rhs_fn)) { // injectivity transition - expr new_type = binding_domain(m_tc->whnf(m_tc->infer(no_confusion).first).first); + expr new_type = binding_domain(m_tc.whnf(m_tc.infer(no_confusion).first).first); if (m_proof_irrel) hyps.pop_back(); // remove processed equality expr new_mvar = mk_metavar(m_ngen.next(), Pi(hyps, new_type)); @@ -526,12 +569,12 @@ class inversion_tac { assign(g.get_name(), val); unsigned A_nparams = *inductive::get_num_params(m_env, const_name(A_fn)); lean_assert(lhs_args.size() >= A_nparams); - return unify_eqs(new_g, neqs - 1 + lhs_args.size() - A_nparams); + return unify_eqs(new_g, imps, neqs - 1 + lhs_args.size() - A_nparams); } else { // conflict transition, eq is of the form c_1 ... = c_2 ..., where c_1 and c_2 are different constructors/intro rules. expr val = g.abstract(lift_down(no_confusion)); assign(g.get_name(), val); - return optional(); // goal has been solved + return unify_result(); // goal has been solved } } if (is_local(rhs)) { @@ -559,8 +602,9 @@ class inversion_tac { buffer non_deps, deps; split_deps(hyps, rhs, non_deps, deps); expr deps_g_type = Pi(deps, g_type); - level eq_rec_lvl1 = sort_level(m_tc->ensure_type(deps_g_type).first); - level eq_rec_lvl2 = sort_level(m_tc->ensure_type(A).first); + abstract_locals(imps, deps); + level eq_rec_lvl1 = sort_level(m_tc.ensure_type(deps_g_type).first); + level eq_rec_lvl2 = sort_level(m_tc.ensure_type(A).first); expr tformer; if (m_proof_irrel) tformer = Fun(rhs, deps_g_type); @@ -571,15 +615,18 @@ class inversion_tac { buffer new_hyps; new_hyps.append(non_deps); expr new_type = instantiate(abstract_local(deps_g_type, rhs), lhs); + abstract_local(imps, rhs); + instantiate(imps, lhs); if (!m_proof_irrel) { new_type = abstract_local(new_type, eq); - new_type = instantiate(new_type, mk_refl(*m_tc, lhs)); + new_type = instantiate(new_type, mk_refl(m_tc, lhs)); } for (unsigned i = 0; i < deps.size(); i++) { expr new_hyp = mk_local(m_ngen.next(), binding_name(new_type), binding_domain(new_type), binding_info(new_type)); new_hyps.push_back(new_hyp); new_type = instantiate(binding_body(new_type), new_hyp); + instantiate(imps, new_hyp); } expr new_mvar = mk_metavar(m_ngen.next(), Pi(new_hyps, new_type)); expr new_meta = mk_app(new_mvar, new_hyps); @@ -588,7 +635,7 @@ class inversion_tac { eq_rec = mk_app(eq_rec, eq_rec_minor, rhs, eq); expr val = g.abstract(mk_app(eq_rec, deps)); assign(g.get_name(), val); - return unify_eqs(new_g, neqs-1); + return unify_eqs(new_g, imps, neqs-1); } else if (is_local(lhs)) { // flip equation and reduce to previous case if (m_proof_irrel) @@ -598,65 +645,97 @@ class inversion_tac { expr new_mvar = mk_metavar(m_ngen.next(), Pi(hyps, new_type)); expr new_meta = mk_app(new_mvar, hyps); goal new_g(new_meta, new_type); - level eq_symm_lvl = sort_level(m_tc->ensure_type(A).first); + level eq_symm_lvl = sort_level(m_tc.ensure_type(A).first); expr symm_pr = mk_constant(name{"eq", "symm"}, {eq_symm_lvl}); symm_pr = mk_app(symm_pr, A, lhs, rhs, eq); expr val = g.abstract(mk_app(new_meta, symm_pr)); assign(g.get_name(), val); - return unify_eqs(new_g, neqs); + return unify_eqs(new_g, imps, neqs); } - // unification failed - return optional(g); + throw inversion_exception("unification failed"); } - list unify_eqs(list const & gs) { + auto unify_eqs(list const & gs, list nargs, list imps) -> + std::tuple, list, list> { + lean_assert(length(gs) == length(imps)); unsigned neqs = m_nindices + (m_dep_elim ? 1 : 0); - buffer new_goals; + buffer new_goals; + buffer new_nargs; + buffer new_imps; for (goal const & g : gs) { - if (optional new_g = unify_eqs(g, neqs)) - new_goals.push_back(*new_g); + if (auto g_imp_pair = unify_eqs(g, head(imps), neqs)) { + new_goals.push_back(g_imp_pair->first); + new_nargs.push_back(head(nargs)); + new_imps.push_back(g_imp_pair->second); + } + imps = tail(imps); + nargs = tail(nargs); } - return to_list(new_goals.begin(), new_goals.end()); + return std::make_tuple(to_list(new_goals), to_list(new_nargs), to_list(new_imps)); } public: - inversion_tac(environment const & env, io_state const & ios, proof_state const & ps, list const & ids): - m_env(env), m_ios(ios), m_ps(ps), m_ids(ids), - m_ngen(m_ps.get_ngen()), m_subst(m_ps.get_subst()), - m_tc(mk_type_checker(m_env, m_ngen.mk_child(), m_ps.relax_main_opaque())) { + inversion_tac(environment const & env, io_state const & ios, name_generator const & ngen, + type_checker & tc, substitution const & subst, list const & ids): + m_env(env), m_ios(ios), m_tc(tc), m_ids(ids), + m_ngen(ngen), m_subst(subst) { m_proof_irrel = m_env.prop_proof_irrel(); } - optional execute(name const & n) { + inversion_tac(environment const & env, io_state const & ios, type_checker & tc): + inversion_tac(env, ios, tc.mk_ngen(), tc, substitution(), list()) {} + + + typedef inversion::result result; + + optional execute(goal const & g, name const & n, implementation_list const & imps) { try { - goals const & gs = m_ps.get_goals(); - if (empty(gs)) - return none_proof_state(); - goal g = head(gs); - goals tail_gs = tail(gs); - auto p = g.find_hyp(n); + auto p = g.find_hyp(n); if (!p) - return none_proof_state(); - expr const & h = p->first; - expr h_type = m_tc->whnf(mlocal_type(h)).first; + return optional(); + expr const & h = p->first; + expr h_type = m_tc.whnf(mlocal_type(h)).first; if (!is_inversion_applicable(h_type)) - return none_proof_state(); - goal g1 = generalize_indices(g, h, h_type); - list gs2 = apply_cases_on(g1); - list gs3 = intros_minors_args(gs2); - list gs4 = unify_eqs(gs3); - proof_state new_s(m_ps, append(gs4, tail_gs), m_subst, m_ngen); - return some_proof_state(new_s); + return optional(); + goal g1 = generalize_indices(g, h, h_type); + auto gs_imps_pair = apply_cases_on(g1, imps); + list gs2 = gs_imps_pair.first; + list new_imps = gs_imps_pair.second; + auto gs_nargs_pair = intros_minors_args(gs2); + list gs3 = gs_nargs_pair.first; + list nargs = gs_nargs_pair.second; + list gs4; + std::tie(gs4, nargs, new_imps) = unify_eqs(gs3, nargs, new_imps); + return optional(result(gs4, nargs, new_imps, m_ngen, m_subst)); } catch (inversion_exception & ex) { - return none_proof_state(); + return optional(); } } }; +namespace inversion { +optional apply(environment const & env, io_state const & ios, type_checker & tc, + goal const & g, name const & n, implementation_list const & imps) { + return inversion_tac(env, ios, tc).execute(g, n, imps); +} +} + tactic inversion_tactic(name const & n, list const & ids) { auto fn = [=](environment const & env, io_state const & ios, proof_state const & ps) -> optional { - inversion_tac tac(env, ios, ps, ids); - return tac.execute(n); + goals const & gs = ps.get_goals(); + if (empty(gs)) + return none_proof_state(); + goal g = head(gs); + goals tail_gs = tail(gs); + name_generator ngen = ps.get_ngen(); + std::unique_ptr tc = mk_type_checker(env, ngen.mk_child(), ps.relax_main_opaque()); + inversion_tac tac(env, ios, ngen, *tc, ps.get_subst(), ids); + if (auto res = tac.execute(g, n, implementation_list())) { + proof_state new_s(ps, append(res->m_goals, tail_gs), res->m_subst, res->m_ngen); + return some_proof_state(new_s); + } else { + return none_proof_state(); + } }; return tactic01(fn); } diff --git a/src/library/tactic/inversion_tactic.h b/src/library/tactic/inversion_tactic.h index 83864ee31..4a4f30162 100644 --- a/src/library/tactic/inversion_tactic.h +++ b/src/library/tactic/inversion_tactic.h @@ -5,9 +5,57 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #pragma once +#include +#include #include "library/tactic/tactic.h" namespace lean { +namespace inversion { +/** + \brief When we apply the inversion tactic/procedure on a hypothesis (h : I A j), where + I is an inductive datatpe, the hypothesis is "broken" into cases: one for each constructor. + Some cases may be in conflict with the type (I A j) and may be suppressed. + + Example of conflict: given the vector type + inductive vector (A : Type) : nat → Type := + nil {} : vector A zero, + cons : Π {n : nat}, A → vector A n → vector A (succ n) + Then, (h : vector A (succ n)) is in conflict with constructor nil. + + The user may provide possible implementations (example: in the form of equations). + Each possible implementation is associated with a case/constructor. + + The inversion tactic/procedure distributes the implementations over cases. + + The implementations may depend on hypotheses that may be modifed by the inversion procedure. + The method update_exprs of each implementation is invoked to update the expressions of + a given implementation. +*/ +class implementation { +public: + virtual ~implementation() {} + virtual name const & get_constructor_name() const = 0; + virtual void update_exprs(std::function const & fn) = 0; +}; + +typedef std::shared_ptr implementation_ptr; +typedef list implementation_list; + +struct result { + list m_goals; + list m_num_args; + list m_implementation_lists; + // invariant: length(m_goals) == length(m_num_args); + // invariant: length(m_goals) == length(m_implementation_lists); + name_generator m_ngen; + substitution m_subst; + result(list const & gs, list const & num_args, list const & imps, + name_generator const & ngen, substitution const & subst); +}; + +optional apply(environment const & env, io_state const & ios, type_checker & tc, + goal const & g, name const & n, implementation_list const & imps); +} tactic inversion_tactic(name const & n, list const & ids = list()); void initialize_inversion_tactic(); void finalize_inversion_tactic();