feat(library/tactic/inversion_tactic): add inversion::apply procedure

The new procedure is essentially a "customized" version of the
inversion (aka cases) tactic for the equations package
This commit is contained in:
Leonardo de Moura 2014-12-30 21:22:50 -08:00
parent f370871574
commit 1f13bfa4f7
2 changed files with 202 additions and 75 deletions

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura Author: Leonardo de Moura
*/ */
#include <utility>
#include "util/sstream.h" #include "util/sstream.h"
#include "kernel/abstract.h" #include "kernel/abstract.h"
#include "kernel/instantiate.h" #include "kernel/instantiate.h"
@ -15,8 +16,19 @@ Author: Leonardo de Moura
#include "library/tactic/tactic.h" #include "library/tactic/tactic.h"
#include "library/tactic/expr_to_tactic.h" #include "library/tactic/expr_to_tactic.h"
#include "library/tactic/class_instance_synth.h" #include "library/tactic/class_instance_synth.h"
#include "library/tactic/inversion_tactic.h"
namespace lean { namespace lean {
namespace inversion {
result::result(list<goal> const & gs, list<unsigned> const & num_args, list<implementation_list> const & imps,
name_generator const & ngen, substitution const & subst):
m_goals(gs), m_num_args(num_args), m_implementation_lists(imps),
m_ngen(ngen), m_subst(subst) {
lean_assert_eq(length(m_goals), length(m_num_args));
lean_assert_eq(length(m_goals), length(m_implementation_lists));
}
}
/** /**
\brief Given eq_rec of the form <tt>@eq.rec.{l l} A a (λ (a' : A) (h : a = a'), B a') b a p</tt>, \brief Given eq_rec of the form <tt>@eq.rec.{l l} A a (λ (a' : A) (h : a = a'), B a') b a p</tt>,
apply the eq_rec_eq definition to produce the equality apply the eq_rec_eq definition to produce the equality
@ -61,14 +73,36 @@ optional<expr> apply_eq_rec_eq(type_checker & tc, io_state const & ios, list<exp
return some_expr(mk_app({r, A, B, *is_hset_A, a, b, p})); return some_expr(mk_app({r, A, B, *is_hset_A, a, b, p}));
} }
typedef inversion::implementation_ptr implementation_ptr;
typedef inversion::implementation_list implementation_list;
static void abstract_locals(implementation_list const & imps, unsigned nlocals, expr const * locals) {
for (implementation_ptr const & imp : imps) {
imp->update_exprs([&](expr const & e) { return abstract_locals(e, nlocals, locals); });
}
}
static void instantiate(implementation_list const & imps, expr const & local) {
for (implementation_ptr const & imp : imps) {
imp->update_exprs([&](expr const & e) { return instantiate(e, local); });
}
}
static void abstract_locals(implementation_list const & imps, buffer<expr> const & locals) {
abstract_locals(imps, locals.size(), locals.data());
}
static void abstract_local(implementation_list const & imps, expr const & local) {
abstract_locals(imps, 1, &local);
}
class inversion_tac { class inversion_tac {
environment const & m_env; environment const & m_env;
io_state const & m_ios; io_state const & m_ios;
proof_state const & m_ps; type_checker & m_tc;
list<name> m_ids; list<name> m_ids;
name_generator m_ngen; name_generator m_ngen;
substitution m_subst; substitution m_subst;
std::unique_ptr<type_checker> m_tc;
bool m_dep_elim; bool m_dep_elim;
bool m_proof_irrel; bool m_proof_irrel;
@ -109,11 +143,11 @@ class inversion_tac {
} }
pair<expr, expr> mk_eq(expr const & lhs, expr const & rhs) { pair<expr, expr> mk_eq(expr const & lhs, expr const & rhs) {
expr lhs_type = m_tc->infer(lhs).first; expr lhs_type = m_tc.infer(lhs).first;
expr rhs_type = m_tc->infer(rhs).first; expr rhs_type = m_tc.infer(rhs).first;
level l = sort_level(m_tc->ensure_type(lhs_type).first); level l = sort_level(m_tc.ensure_type(lhs_type).first);
constraint_seq cs; constraint_seq cs;
if (m_tc->is_def_eq(lhs_type, rhs_type, justification(), cs) && !cs) { if (m_tc.is_def_eq(lhs_type, rhs_type, justification(), cs) && !cs) {
return mk_pair(mk_app(mk_constant("eq", to_list(l)), lhs_type, lhs, rhs), return mk_pair(mk_app(mk_constant("eq", to_list(l)), lhs_type, lhs, rhs),
mk_app(mk_constant({"eq", "refl"}, to_list(l)), rhs_type, rhs)); mk_app(mk_constant({"eq", "refl"}, to_list(l)), rhs_type, rhs));
} else { } else {
@ -135,7 +169,7 @@ class inversion_tac {
buffer<expr> I_args; buffer<expr> I_args;
expr const & I = get_app_args(h_type, I_args); expr const & I = get_app_args(h_type, I_args);
expr h_new_type = mk_app(I, I_args.size() - m_nindices, I_args.data()); expr h_new_type = mk_app(I, I_args.size() - m_nindices, I_args.data());
expr d = m_tc->whnf(m_tc->infer(h_new_type).first).first; expr d = m_tc.whnf(m_tc.infer(h_new_type).first).first;
name t_prefix("t"); name t_prefix("t");
unsigned nidx = 1; unsigned nidx = 1;
if (m_proof_irrel) { if (m_proof_irrel) {
@ -182,7 +216,7 @@ class inversion_tac {
expr t = mk_local(m_ngen.next(), g.get_unused_name(t_prefix, nidx), t_type, binder_info()); expr t = mk_local(m_ngen.next(), g.get_unused_name(t_prefix, nidx), t_type, binder_info());
h_new_type = mk_app(h_new_type, t); h_new_type = mk_app(h_new_type, t);
ss.push_back(I_args[i]); ss.push_back(I_args[i]);
refls.push_back(mk_refl(*m_tc, I_args[i])); refls.push_back(mk_refl(m_tc, I_args[i]));
hyps.push_back(t); hyps.push_back(t);
ts.push_back(t); ts.push_back(t);
d = instantiate(binding_body(d), t); d = instantiate(binding_body(d), t);
@ -190,10 +224,10 @@ class inversion_tac {
expr h_new = mk_local(m_ngen.next(), h_new_name, h_new_type, local_info(h)); expr h_new = mk_local(m_ngen.next(), h_new_name, h_new_type, local_info(h));
ts.push_back(h_new); ts.push_back(h_new);
ss.push_back(h); ss.push_back(h);
refls.push_back(mk_refl(*m_tc, h)); refls.push_back(mk_refl(m_tc, h));
hyps.push_back(h_new); hyps.push_back(h_new);
buffer<expr> eqs; buffer<expr> eqs;
mk_telescopic_eq(*m_tc, ss, ts, eqs); mk_telescopic_eq(m_tc, ss, ts, eqs);
ts.pop_back(); ts.pop_back();
expr new_type = Pi(eqs, g.get_type()); expr new_type = Pi(eqs, g.get_type());
expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(hyps, new_type)), hyps); expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(hyps, new_type)), hyps);
@ -205,20 +239,21 @@ class inversion_tac {
} }
} }
list<goal> apply_cases_on(goal const & g) { std::pair<list<goal>, list<implementation_list>> apply_cases_on(goal const & g, implementation_list const & imps) {
buffer<expr> hyps; buffer<expr> hyps;
g.get_hyps(hyps); g.get_hyps(hyps);
expr const & h = hyps.back(); expr const & h = hyps.back();
expr const & h_type = mlocal_type(h); expr const & h_type = mlocal_type(h);
buffer<expr> I_args; buffer<expr> I_args;
expr const & I = get_app_args(h_type, I_args); expr const & I = get_app_args(h_type, I_args);
name const & I_name = const_name(I);
expr g_type = g.get_type(); expr g_type = g.get_type();
expr cases_on; expr cases_on;
if (length(m_cases_on_decl.get_univ_params()) != length(m_I_decl.get_univ_params())) { if (length(m_cases_on_decl.get_univ_params()) != length(m_I_decl.get_univ_params())) {
level g_lvl = sort_level(m_tc->ensure_type(g_type).first); level g_lvl = sort_level(m_tc.ensure_type(g_type).first);
cases_on = mk_constant({const_name(I), "cases_on"}, cons(g_lvl, const_levels(I))); cases_on = mk_constant({I_name, "cases_on"}, cons(g_lvl, const_levels(I)));
} else { } else {
cases_on = mk_constant({const_name(I), "cases_on"}, const_levels(I)); cases_on = mk_constant({I_name, "cases_on"}, const_levels(I));
} }
// add params // add params
cases_on = mk_app(cases_on, m_nparams, I_args.data()); cases_on = mk_app(cases_on, m_nparams, I_args.data());
@ -232,22 +267,28 @@ class inversion_tac {
cases_on = mk_app(cases_on, m_nindices, I_args.end() - m_nindices); cases_on = mk_app(cases_on, m_nindices, I_args.end() - m_nindices);
// add h // add h
cases_on = mk_app(cases_on, h); cases_on = mk_app(cases_on, h);
buffer<name> intro_names;
get_intro_rule_names(m_env, I_name, intro_names);
lean_assert(m_nminors == intro_names.size());
buffer<expr> new_hyps; buffer<expr> new_hyps;
new_hyps.append(hyps.size() - m_nindices - 1, hyps.data()); new_hyps.append(hyps.size() - m_nindices - 1, hyps.data());
// add a subgoal for each minor premise of cases_on // add a subgoal for each minor premise of cases_on
expr cases_on_type = m_tc->whnf(m_tc->infer(cases_on).first).first; expr cases_on_type = m_tc.whnf(m_tc.infer(cases_on).first).first;
buffer<goal> new_goals; buffer<goal> new_goals;
buffer<implementation_list> new_imps;
for (unsigned i = 0; i < m_nminors; i++) { for (unsigned i = 0; i < m_nminors; i++) {
expr new_type = binding_domain(cases_on_type); expr new_type = binding_domain(cases_on_type);
expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(new_hyps, new_type)), new_hyps); expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(new_hyps, new_type)), new_hyps);
goal new_g(new_meta, new_type); goal new_g(new_meta, new_type);
new_goals.push_back(new_g); new_goals.push_back(new_g);
cases_on = mk_app(cases_on, new_meta); cases_on = mk_app(cases_on, new_meta);
cases_on_type = m_tc->whnf(binding_body(cases_on_type)).first; // the minor premises do not depend on each other cases_on_type = m_tc.whnf(binding_body(cases_on_type)).first; // the minor premises do not depend on each other
name const & intro_name = intro_names[i];
new_imps.push_back(filter(imps, [&](implementation_ptr const & imp) { return imp->get_constructor_name() == intro_name; }));
} }
expr val = g.abstract(cases_on); expr val = g.abstract(cases_on);
assign(g.get_name(), val); assign(g.get_name(), val);
return to_list(new_goals.begin(), new_goals.end()); return mk_pair(to_list(new_goals), to_list(new_imps));
} }
// Store in \c r the number of arguments for each cases_on minor. // Store in \c r the number of arguments for each cases_on minor.
@ -267,7 +308,7 @@ class inversion_tac {
} }
} }
list<goal> intros_minors_args(list<goal> gs) { std::pair<list<goal>, list<unsigned>> intros_minors_args(list<goal> gs) {
buffer<unsigned> minors_nargs; buffer<unsigned> minors_nargs;
get_minors_nargs(minors_nargs); get_minors_nargs(minors_nargs);
lean_assert(length(gs) == minors_nargs.size()); lean_assert(length(gs) == minors_nargs.size());
@ -301,7 +342,7 @@ class inversion_tac {
assign(g.get_name(), val); assign(g.get_name(), val);
gs = tail(gs); gs = tail(gs);
} }
return to_list(new_gs.begin(), new_gs.end()); return mk_pair(to_list(new_gs), to_list(minors_nargs));
} }
struct inversion_exception : public exception { struct inversion_exception : public exception {
@ -338,19 +379,19 @@ class inversion_tac {
lean_assert(is_eq_rec(lhs)); lean_assert(is_eq_rec(lhs));
// lhs is of the form (eq.rec A s C a s p) // lhs is of the form (eq.rec A s C a s p)
// aux_eq is a term of type ((eq.rec A s C a s p) = a) // aux_eq is a term of type ((eq.rec A s C a s p) = a)
auto aux_eq = apply_eq_rec_eq(*m_tc, m_ios, to_list(hyps), lhs); auto aux_eq = apply_eq_rec_eq(m_tc, m_ios, to_list(hyps), lhs);
if (!aux_eq) if (!aux_eq)
throw_unification_eq_rec_failure(); throw_unification_eq_rec_failure();
buffer<expr> lhs_args; buffer<expr> lhs_args;
get_app_args(lhs, lhs_args); get_app_args(lhs, lhs_args);
expr const & reduced_lhs = lhs_args[3]; expr const & reduced_lhs = lhs_args[3];
expr new_eq = ::lean::mk_eq(*m_tc, reduced_lhs, rhs); expr new_eq = ::lean::mk_eq(m_tc, reduced_lhs, rhs);
expr new_type = update_binding(type, new_eq, binding_body(type)); expr new_type = update_binding(type, new_eq, binding_body(type));
expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(hyps, new_type)), hyps); expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(hyps, new_type)), hyps);
goal new_g(new_meta, new_type); goal new_g(new_meta, new_type);
// create assignment for g // create assignment for g
expr A = m_tc->infer(lhs).first; expr A = m_tc.infer(lhs).first;
level lvl = sort_level(m_tc->ensure_type(A).first); level lvl = sort_level(m_tc.ensure_type(A).first);
// old_eq : eq.rec A s C a s p = b // old_eq : eq.rec A s C a s p = b
expr old_eq = mk_local(m_ngen.next(), binding_name(type), eq, binder_info()); expr old_eq = mk_local(m_ngen.next(), binding_name(type), eq, binder_info());
// aux_eq : a = eq.rec A s C a s p // aux_eq : a = eq.rec A s C a s p
@ -373,7 +414,7 @@ class inversion_tac {
buffer<expr> args; buffer<expr> args;
expr const & heq_fn = get_app_args(eq, args); expr const & heq_fn = get_app_args(eq, args);
constraint_seq cs; constraint_seq cs;
if (m_tc->is_def_eq(args[0], args[2], justification(), cs) && !cs) { if (m_tc.is_def_eq(args[0], args[2], justification(), cs) && !cs) {
buffer<expr> hyps; buffer<expr> hyps;
g.get_hyps(hyps); g.get_hyps(hyps);
expr new_eq = mk_app(mk_constant("eq", const_levels(heq_fn)), args[0], args[1], args[3]); expr new_eq = mk_app(mk_constant("eq", const_levels(heq_fn)), args[0], args[1], args[3]);
@ -425,8 +466,8 @@ class inversion_tac {
if (const_name(eq_fn) == "eq") { if (const_name(eq_fn) == "eq") {
expr const & lhs = app_arg(app_fn(eq)); expr const & lhs = app_arg(app_fn(eq));
expr const & rhs = app_arg(eq); expr const & rhs = app_arg(eq);
expr new_lhs = m_tc->whnf(lhs).first; expr new_lhs = m_tc.whnf(lhs).first;
expr new_rhs = m_tc->whnf(rhs).first; expr new_rhs = m_tc.whnf(rhs).first;
if (lhs != new_lhs || rhs != new_rhs) { if (lhs != new_lhs || rhs != new_rhs) {
eq = mk_app(app_fn(app_fn(eq)), new_lhs, new_rhs); eq = mk_app(app_fn(app_fn(eq)), new_lhs, new_rhs);
type = update_binding(type, eq, binding_body(type)); type = update_binding(type, eq, binding_body(type));
@ -461,7 +502,7 @@ class inversion_tac {
// We must apply lift.down to eliminate the auxiliary lift. // We must apply lift.down to eliminate the auxiliary lift.
expr lift_down(expr const & v) { expr lift_down(expr const & v) {
if (!m_proof_irrel) { if (!m_proof_irrel) {
expr v_type = m_tc->whnf(m_tc->infer(v).first).first; expr v_type = m_tc.whnf(m_tc.infer(v).first).first;
if (!is_app(v_type)) if (!is_app(v_type))
throw_unification_eq_rec_failure(); throw_unification_eq_rec_failure();
expr const & lift = app_fn(v_type); expr const & lift = app_fn(v_type);
@ -473,9 +514,11 @@ class inversion_tac {
} }
} }
optional<goal> unify_eqs(goal g, unsigned neqs) { typedef optional<std::pair<goal, implementation_list>> unify_result;
unify_result unify_eqs(goal g, implementation_list imps, unsigned neqs) {
if (neqs == 0) if (neqs == 0)
return optional<goal>(g); // done return unify_result(g, imps); // done
g = intro_next_eq(g); g = intro_next_eq(g);
buffer<expr> hyps; buffer<expr> hyps;
g.get_hyps(hyps); g.get_hyps(hyps);
@ -483,11 +526,11 @@ class inversion_tac {
expr eq = hyps.back(); expr eq = hyps.back();
buffer<expr> eq_args; buffer<expr> eq_args;
get_app_args(mlocal_type(eq), eq_args); get_app_args(mlocal_type(eq), eq_args);
expr const & A = m_tc->whnf(eq_args[0]).first; expr const & A = m_tc.whnf(eq_args[0]).first;
expr lhs = m_tc->whnf(eq_args[1]).first; expr lhs = m_tc.whnf(eq_args[1]).first;
expr rhs = m_tc->whnf(eq_args[2]).first; expr rhs = m_tc.whnf(eq_args[2]).first;
constraint_seq cs; constraint_seq cs;
if (m_proof_irrel && m_tc->is_def_eq(lhs, rhs, justification(), cs) && !cs) { if (m_proof_irrel && m_tc.is_def_eq(lhs, rhs, justification(), cs) && !cs) {
// deletion transition: t == t // deletion transition: t == t
hyps.pop_back(); // remove t == t equality hyps.pop_back(); // remove t == t equality
expr new_type = g.get_type(); expr new_type = g.get_type();
@ -495,13 +538,13 @@ class inversion_tac {
goal new_g(new_meta, new_type); goal new_g(new_meta, new_type);
expr val = g.abstract(new_meta); expr val = g.abstract(new_meta);
assign(g.get_name(), val); assign(g.get_name(), val);
return unify_eqs(new_g, neqs-1); return unify_eqs(new_g, imps, neqs-1);
} }
buffer<expr> lhs_args, rhs_args; buffer<expr> lhs_args, rhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args); expr const & lhs_fn = get_app_args(lhs, lhs_args);
expr const & rhs_fn = get_app_args(rhs, rhs_args); expr const & rhs_fn = get_app_args(rhs, rhs_args);
expr const & g_type = g.get_type(); expr const & g_type = g.get_type();
level const & g_lvl = sort_level(m_tc->ensure_type(g_type).first); level const & g_lvl = sort_level(m_tc.ensure_type(g_type).first);
if (is_constant(lhs_fn) && if (is_constant(lhs_fn) &&
is_constant(rhs_fn) && is_constant(rhs_fn) &&
inductive::is_intro_rule(m_env, const_name(lhs_fn)) && inductive::is_intro_rule(m_env, const_name(lhs_fn)) &&
@ -516,7 +559,7 @@ class inversion_tac {
expr no_confusion = mk_app(mk_app(mk_constant(no_confusion_name, cons(g_lvl, const_levels(A_fn))), A_args), g_type, lhs, rhs, eq); expr no_confusion = mk_app(mk_app(mk_constant(no_confusion_name, cons(g_lvl, const_levels(A_fn))), A_args), g_type, lhs, rhs, eq);
if (const_name(lhs_fn) == const_name(rhs_fn)) { if (const_name(lhs_fn) == const_name(rhs_fn)) {
// injectivity transition // injectivity transition
expr new_type = binding_domain(m_tc->whnf(m_tc->infer(no_confusion).first).first); expr new_type = binding_domain(m_tc.whnf(m_tc.infer(no_confusion).first).first);
if (m_proof_irrel) if (m_proof_irrel)
hyps.pop_back(); // remove processed equality hyps.pop_back(); // remove processed equality
expr new_mvar = mk_metavar(m_ngen.next(), Pi(hyps, new_type)); expr new_mvar = mk_metavar(m_ngen.next(), Pi(hyps, new_type));
@ -526,12 +569,12 @@ class inversion_tac {
assign(g.get_name(), val); assign(g.get_name(), val);
unsigned A_nparams = *inductive::get_num_params(m_env, const_name(A_fn)); unsigned A_nparams = *inductive::get_num_params(m_env, const_name(A_fn));
lean_assert(lhs_args.size() >= A_nparams); lean_assert(lhs_args.size() >= A_nparams);
return unify_eqs(new_g, neqs - 1 + lhs_args.size() - A_nparams); return unify_eqs(new_g, imps, neqs - 1 + lhs_args.size() - A_nparams);
} else { } else {
// conflict transition, eq is of the form c_1 ... = c_2 ..., where c_1 and c_2 are different constructors/intro rules. // conflict transition, eq is of the form c_1 ... = c_2 ..., where c_1 and c_2 are different constructors/intro rules.
expr val = g.abstract(lift_down(no_confusion)); expr val = g.abstract(lift_down(no_confusion));
assign(g.get_name(), val); assign(g.get_name(), val);
return optional<goal>(); // goal has been solved return unify_result(); // goal has been solved
} }
} }
if (is_local(rhs)) { if (is_local(rhs)) {
@ -559,8 +602,9 @@ class inversion_tac {
buffer<expr> non_deps, deps; buffer<expr> non_deps, deps;
split_deps(hyps, rhs, non_deps, deps); split_deps(hyps, rhs, non_deps, deps);
expr deps_g_type = Pi(deps, g_type); expr deps_g_type = Pi(deps, g_type);
level eq_rec_lvl1 = sort_level(m_tc->ensure_type(deps_g_type).first); abstract_locals(imps, deps);
level eq_rec_lvl2 = sort_level(m_tc->ensure_type(A).first); level eq_rec_lvl1 = sort_level(m_tc.ensure_type(deps_g_type).first);
level eq_rec_lvl2 = sort_level(m_tc.ensure_type(A).first);
expr tformer; expr tformer;
if (m_proof_irrel) if (m_proof_irrel)
tformer = Fun(rhs, deps_g_type); tformer = Fun(rhs, deps_g_type);
@ -571,15 +615,18 @@ class inversion_tac {
buffer<expr> new_hyps; buffer<expr> new_hyps;
new_hyps.append(non_deps); new_hyps.append(non_deps);
expr new_type = instantiate(abstract_local(deps_g_type, rhs), lhs); expr new_type = instantiate(abstract_local(deps_g_type, rhs), lhs);
abstract_local(imps, rhs);
instantiate(imps, lhs);
if (!m_proof_irrel) { if (!m_proof_irrel) {
new_type = abstract_local(new_type, eq); new_type = abstract_local(new_type, eq);
new_type = instantiate(new_type, mk_refl(*m_tc, lhs)); new_type = instantiate(new_type, mk_refl(m_tc, lhs));
} }
for (unsigned i = 0; i < deps.size(); i++) { for (unsigned i = 0; i < deps.size(); i++) {
expr new_hyp = mk_local(m_ngen.next(), binding_name(new_type), binding_domain(new_type), expr new_hyp = mk_local(m_ngen.next(), binding_name(new_type), binding_domain(new_type),
binding_info(new_type)); binding_info(new_type));
new_hyps.push_back(new_hyp); new_hyps.push_back(new_hyp);
new_type = instantiate(binding_body(new_type), new_hyp); new_type = instantiate(binding_body(new_type), new_hyp);
instantiate(imps, new_hyp);
} }
expr new_mvar = mk_metavar(m_ngen.next(), Pi(new_hyps, new_type)); expr new_mvar = mk_metavar(m_ngen.next(), Pi(new_hyps, new_type));
expr new_meta = mk_app(new_mvar, new_hyps); expr new_meta = mk_app(new_mvar, new_hyps);
@ -588,7 +635,7 @@ class inversion_tac {
eq_rec = mk_app(eq_rec, eq_rec_minor, rhs, eq); eq_rec = mk_app(eq_rec, eq_rec_minor, rhs, eq);
expr val = g.abstract(mk_app(eq_rec, deps)); expr val = g.abstract(mk_app(eq_rec, deps));
assign(g.get_name(), val); assign(g.get_name(), val);
return unify_eqs(new_g, neqs-1); return unify_eqs(new_g, imps, neqs-1);
} else if (is_local(lhs)) { } else if (is_local(lhs)) {
// flip equation and reduce to previous case // flip equation and reduce to previous case
if (m_proof_irrel) if (m_proof_irrel)
@ -598,65 +645,97 @@ class inversion_tac {
expr new_mvar = mk_metavar(m_ngen.next(), Pi(hyps, new_type)); expr new_mvar = mk_metavar(m_ngen.next(), Pi(hyps, new_type));
expr new_meta = mk_app(new_mvar, hyps); expr new_meta = mk_app(new_mvar, hyps);
goal new_g(new_meta, new_type); goal new_g(new_meta, new_type);
level eq_symm_lvl = sort_level(m_tc->ensure_type(A).first); level eq_symm_lvl = sort_level(m_tc.ensure_type(A).first);
expr symm_pr = mk_constant(name{"eq", "symm"}, {eq_symm_lvl}); expr symm_pr = mk_constant(name{"eq", "symm"}, {eq_symm_lvl});
symm_pr = mk_app(symm_pr, A, lhs, rhs, eq); symm_pr = mk_app(symm_pr, A, lhs, rhs, eq);
expr val = g.abstract(mk_app(new_meta, symm_pr)); expr val = g.abstract(mk_app(new_meta, symm_pr));
assign(g.get_name(), val); assign(g.get_name(), val);
return unify_eqs(new_g, neqs); return unify_eqs(new_g, imps, neqs);
} }
// unification failed throw inversion_exception("unification failed");
return optional<goal>(g);
} }
list<goal> unify_eqs(list<goal> const & gs) { auto unify_eqs(list<goal> const & gs, list<unsigned> nargs, list<implementation_list> imps) ->
std::tuple<list<goal>, list<unsigned>, list<implementation_list>> {
lean_assert(length(gs) == length(imps));
unsigned neqs = m_nindices + (m_dep_elim ? 1 : 0); unsigned neqs = m_nindices + (m_dep_elim ? 1 : 0);
buffer<goal> new_goals; buffer<goal> new_goals;
buffer<unsigned> new_nargs;
buffer<implementation_list> new_imps;
for (goal const & g : gs) { for (goal const & g : gs) {
if (optional<goal> new_g = unify_eqs(g, neqs)) if (auto g_imp_pair = unify_eqs(g, head(imps), neqs)) {
new_goals.push_back(*new_g); new_goals.push_back(g_imp_pair->first);
new_nargs.push_back(head(nargs));
new_imps.push_back(g_imp_pair->second);
} }
return to_list(new_goals.begin(), new_goals.end()); imps = tail(imps);
nargs = tail(nargs);
}
return std::make_tuple(to_list(new_goals), to_list(new_nargs), to_list(new_imps));
} }
public: public:
inversion_tac(environment const & env, io_state const & ios, proof_state const & ps, list<name> const & ids): inversion_tac(environment const & env, io_state const & ios, name_generator const & ngen,
m_env(env), m_ios(ios), m_ps(ps), m_ids(ids), type_checker & tc, substitution const & subst, list<name> const & ids):
m_ngen(m_ps.get_ngen()), m_subst(m_ps.get_subst()), m_env(env), m_ios(ios), m_tc(tc), m_ids(ids),
m_tc(mk_type_checker(m_env, m_ngen.mk_child(), m_ps.relax_main_opaque())) { m_ngen(ngen), m_subst(subst) {
m_proof_irrel = m_env.prop_proof_irrel(); m_proof_irrel = m_env.prop_proof_irrel();
} }
optional<proof_state> execute(name const & n) { inversion_tac(environment const & env, io_state const & ios, type_checker & tc):
inversion_tac(env, ios, tc.mk_ngen(), tc, substitution(), list<name>()) {}
typedef inversion::result result;
optional<result> execute(goal const & g, name const & n, implementation_list const & imps) {
try { try {
goals const & gs = m_ps.get_goals();
if (empty(gs))
return none_proof_state();
goal g = head(gs);
goals tail_gs = tail(gs);
auto p = g.find_hyp(n); auto p = g.find_hyp(n);
if (!p) if (!p)
return none_proof_state(); return optional<result>();
expr const & h = p->first; expr const & h = p->first;
expr h_type = m_tc->whnf(mlocal_type(h)).first; expr h_type = m_tc.whnf(mlocal_type(h)).first;
if (!is_inversion_applicable(h_type)) if (!is_inversion_applicable(h_type))
return none_proof_state(); return optional<result>();
goal g1 = generalize_indices(g, h, h_type); goal g1 = generalize_indices(g, h, h_type);
list<goal> gs2 = apply_cases_on(g1); auto gs_imps_pair = apply_cases_on(g1, imps);
list<goal> gs3 = intros_minors_args(gs2); list<goal> gs2 = gs_imps_pair.first;
list<goal> gs4 = unify_eqs(gs3); list<implementation_list> new_imps = gs_imps_pair.second;
proof_state new_s(m_ps, append(gs4, tail_gs), m_subst, m_ngen); auto gs_nargs_pair = intros_minors_args(gs2);
return some_proof_state(new_s); list<goal> gs3 = gs_nargs_pair.first;
list<unsigned> nargs = gs_nargs_pair.second;
list<goal> gs4;
std::tie(gs4, nargs, new_imps) = unify_eqs(gs3, nargs, new_imps);
return optional<result>(result(gs4, nargs, new_imps, m_ngen, m_subst));
} catch (inversion_exception & ex) { } catch (inversion_exception & ex) {
return none_proof_state(); return optional<result>();
} }
} }
}; };
namespace inversion {
optional<result> apply(environment const & env, io_state const & ios, type_checker & tc,
goal const & g, name const & n, implementation_list const & imps) {
return inversion_tac(env, ios, tc).execute(g, n, imps);
}
}
tactic inversion_tactic(name const & n, list<name> const & ids) { tactic inversion_tactic(name const & n, list<name> const & ids) {
auto fn = [=](environment const & env, io_state const & ios, proof_state const & ps) -> optional<proof_state> { auto fn = [=](environment const & env, io_state const & ios, proof_state const & ps) -> optional<proof_state> {
inversion_tac tac(env, ios, ps, ids); goals const & gs = ps.get_goals();
return tac.execute(n); if (empty(gs))
return none_proof_state();
goal g = head(gs);
goals tail_gs = tail(gs);
name_generator ngen = ps.get_ngen();
std::unique_ptr<type_checker> tc = mk_type_checker(env, ngen.mk_child(), ps.relax_main_opaque());
inversion_tac tac(env, ios, ngen, *tc, ps.get_subst(), ids);
if (auto res = tac.execute(g, n, implementation_list())) {
proof_state new_s(ps, append(res->m_goals, tail_gs), res->m_subst, res->m_ngen);
return some_proof_state(new_s);
} else {
return none_proof_state();
}
}; };
return tactic01(fn); return tactic01(fn);
} }

View file

@ -5,9 +5,57 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura Author: Leonardo de Moura
*/ */
#pragma once #pragma once
#include <functional>
#include <memory>
#include "library/tactic/tactic.h" #include "library/tactic/tactic.h"
namespace lean { namespace lean {
namespace inversion {
/**
\brief When we apply the inversion tactic/procedure on a hypothesis (h : I A j), where
I is an inductive datatpe, the hypothesis is "broken" into cases: one for each constructor.
Some cases may be in conflict with the type (I A j) and may be suppressed.
Example of conflict: given the vector type
inductive vector (A : Type) : nat Type :=
nil {} : vector A zero,
cons : Π {n : nat}, A vector A n vector A (succ n)
Then, (h : vector A (succ n)) is in conflict with constructor nil.
The user may provide possible implementations (example: in the form of equations).
Each possible implementation is associated with a case/constructor.
The inversion tactic/procedure distributes the implementations over cases.
The implementations may depend on hypotheses that may be modifed by the inversion procedure.
The method update_exprs of each implementation is invoked to update the expressions of
a given implementation.
*/
class implementation {
public:
virtual ~implementation() {}
virtual name const & get_constructor_name() const = 0;
virtual void update_exprs(std::function<expr(expr const &)> const & fn) = 0;
};
typedef std::shared_ptr<implementation> implementation_ptr;
typedef list<implementation_ptr> implementation_list;
struct result {
list<goal> m_goals;
list<unsigned> m_num_args;
list<implementation_list> m_implementation_lists;
// invariant: length(m_goals) == length(m_num_args);
// invariant: length(m_goals) == length(m_implementation_lists);
name_generator m_ngen;
substitution m_subst;
result(list<goal> const & gs, list<unsigned> const & num_args, list<implementation_list> const & imps,
name_generator const & ngen, substitution const & subst);
};
optional<result> apply(environment const & env, io_state const & ios, type_checker & tc,
goal const & g, name const & n, implementation_list const & imps);
}
tactic inversion_tactic(name const & n, list<name> const & ids = list<name>()); tactic inversion_tactic(name const & n, list<name> const & ids = list<name>());
void initialize_inversion_tactic(); void initialize_inversion_tactic();
void finalize_inversion_tactic(); void finalize_inversion_tactic();