feat(library/tactic/inversion_tactic): generate auxiliary information

This commit is contained in:
Leonardo de Moura 2014-12-31 18:55:18 -08:00
parent e76ef18980
commit 761810f350
2 changed files with 78 additions and 47 deletions

View file

@ -20,12 +20,13 @@ Author: Leonardo de Moura
namespace lean { namespace lean {
namespace inversion { namespace inversion {
result::result(list<goal> const & gs, list<unsigned> const & num_args, list<implementation_list> const & imps, result::result(list<goal> const & gs, list<list<name>> const & args, list<implementation_list> const & imps,
name_generator const & ngen, substitution const & subst): list<rename_map> const & rs, name_generator const & ngen, substitution const & subst):
m_goals(gs), m_num_args(num_args), m_implementation_lists(imps), m_goals(gs), m_args(args), m_implementation_lists(imps),
m_ngen(ngen), m_subst(subst) { m_renames(rs), m_ngen(ngen), m_subst(subst) {
lean_assert_eq(length(m_goals), length(m_num_args)); lean_assert_eq(length(m_goals), length(m_args));
lean_assert_eq(length(m_goals), length(m_implementation_lists)); lean_assert_eq(length(m_goals), length(m_implementation_lists));
lean_assert_eq(length(m_goals), length(m_renames));
} }
} }
@ -75,7 +76,6 @@ optional<expr> apply_eq_rec_eq(type_checker & tc, io_state const & ios, list<exp
typedef inversion::implementation_ptr implementation_ptr; typedef inversion::implementation_ptr implementation_ptr;
typedef inversion::implementation_list implementation_list; typedef inversion::implementation_list implementation_list;
static void abstract_locals(implementation_list const & imps, unsigned nlocals, expr const * locals) { static void abstract_locals(implementation_list const & imps, unsigned nlocals, expr const * locals) {
for (implementation_ptr const & imp : imps) { for (implementation_ptr const & imp : imps) {
imp->update_exprs([&](expr const & e) { return abstract_locals(e, nlocals, locals); }); imp->update_exprs([&](expr const & e) { return abstract_locals(e, nlocals, locals); });
@ -308,11 +308,12 @@ class inversion_tac {
} }
} }
std::pair<list<goal>, list<unsigned>> intros_minors_args(list<goal> gs) { std::pair<list<goal>, list<list<name>>> intros_minors_args(list<goal> gs) {
buffer<unsigned> minors_nargs; buffer<unsigned> minors_nargs;
get_minors_nargs(minors_nargs); get_minors_nargs(minors_nargs);
lean_assert(length(gs) == minors_nargs.size()); lean_assert(length(gs) == minors_nargs.size());
buffer<goal> new_gs; buffer<goal> new_gs;
buffer<list<name>> new_args;
for (unsigned i = 0; i < minors_nargs.size(); i++) { for (unsigned i = 0; i < minors_nargs.size(); i++) {
goal const & g = head(gs); goal const & g = head(gs);
unsigned nargs = minors_nargs[i]; unsigned nargs = minors_nargs[i];
@ -321,6 +322,7 @@ class inversion_tac {
buffer<expr> new_hyps; buffer<expr> new_hyps;
new_hyps.append(hyps); new_hyps.append(hyps);
expr g_type = g.get_type(); expr g_type = g.get_type();
buffer<name> curr_new_args;
for (unsigned i = 0; i < nargs; i++) { for (unsigned i = 0; i < nargs; i++) {
expr type = binding_domain(g_type); expr type = binding_domain(g_type);
name new_h_name; name new_h_name;
@ -331,9 +333,11 @@ class inversion_tac {
new_h_name = binding_name(g_type); new_h_name = binding_name(g_type);
} }
expr new_h = mk_local(m_ngen.next(), get_unused_name(new_h_name, new_hyps), type, binder_info()); expr new_h = mk_local(m_ngen.next(), get_unused_name(new_h_name, new_hyps), type, binder_info());
curr_new_args.push_back(mlocal_name(new_h));
new_hyps.push_back(new_h); new_hyps.push_back(new_h);
g_type = instantiate(binding_body(g_type), new_h); g_type = instantiate(binding_body(g_type), new_h);
} }
new_args.push_back(to_list(curr_new_args));
g_type = head_beta_reduce(g_type); g_type = head_beta_reduce(g_type);
expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(new_hyps, g_type)), new_hyps); expr new_meta = mk_app(mk_metavar(m_ngen.next(), Pi(new_hyps, g_type)), new_hyps);
goal new_g(new_meta, g_type); goal new_g(new_meta, g_type);
@ -342,7 +346,7 @@ class inversion_tac {
assign(g.get_name(), val); assign(g.get_name(), val);
gs = tail(gs); gs = tail(gs);
} }
return mk_pair(to_list(new_gs), to_list(minors_nargs)); return mk_pair(to_list(new_gs), to_list(new_args));
} }
struct inversion_exception : public exception { struct inversion_exception : public exception {
@ -514,11 +518,21 @@ class inversion_tac {
} }
} }
typedef optional<std::pair<goal, implementation_list>> unify_result; rename_map m_renames;
implementation_list m_imps;
unify_result unify_eqs(goal g, implementation_list imps, unsigned neqs) { // update m_renames with old_hyps --> new_hyps.
void store_renames(buffer<expr> const & old_hyps, buffer<expr> const & new_hyps) {
lean_assert(old_hyps.size() == new_hyps.size());
for (unsigned i = 0; i < old_hyps.size(); i++) {
m_renames.insert(mlocal_name(old_hyps[i]), mlocal_name(new_hyps[i]));
}
}
// Remark: it also updates m_renames and m_imps
optional<goal> unify_eqs(goal g, unsigned neqs) {
if (neqs == 0) if (neqs == 0)
return unify_result(g, imps); // done return optional<goal>(g); // done
g = intro_next_eq(g); g = intro_next_eq(g);
buffer<expr> hyps; buffer<expr> hyps;
g.get_hyps(hyps); g.get_hyps(hyps);
@ -538,7 +552,7 @@ class inversion_tac {
goal new_g(new_meta, new_type); goal new_g(new_meta, new_type);
expr val = g.abstract(new_meta); expr val = g.abstract(new_meta);
assign(g.get_name(), val); assign(g.get_name(), val);
return unify_eqs(new_g, imps, neqs-1); return unify_eqs(new_g, neqs-1);
} }
buffer<expr> lhs_args, rhs_args; buffer<expr> lhs_args, rhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args); expr const & lhs_fn = get_app_args(lhs, lhs_args);
@ -569,12 +583,12 @@ class inversion_tac {
assign(g.get_name(), val); assign(g.get_name(), val);
unsigned A_nparams = *inductive::get_num_params(m_env, const_name(A_fn)); unsigned A_nparams = *inductive::get_num_params(m_env, const_name(A_fn));
lean_assert(lhs_args.size() >= A_nparams); lean_assert(lhs_args.size() >= A_nparams);
return unify_eqs(new_g, imps, neqs - 1 + lhs_args.size() - A_nparams); return unify_eqs(new_g, neqs - 1 + lhs_args.size() - A_nparams);
} else { } else {
// conflict transition, eq is of the form c_1 ... = c_2 ..., where c_1 and c_2 are different constructors/intro rules. // conflict transition, eq is of the form c_1 ... = c_2 ..., where c_1 and c_2 are different constructors/intro rules.
expr val = g.abstract(lift_down(no_confusion)); expr val = g.abstract(lift_down(no_confusion));
assign(g.get_name(), val); assign(g.get_name(), val);
return unify_result(); // goal has been solved return optional<goal>(); // goal has been solved
} }
} }
if (is_local(rhs)) { if (is_local(rhs)) {
@ -602,7 +616,7 @@ class inversion_tac {
buffer<expr> non_deps, deps; buffer<expr> non_deps, deps;
split_deps(hyps, rhs, non_deps, deps); split_deps(hyps, rhs, non_deps, deps);
expr deps_g_type = Pi(deps, g_type); expr deps_g_type = Pi(deps, g_type);
abstract_locals(imps, deps); abstract_locals(m_imps, deps);
level eq_rec_lvl1 = sort_level(m_tc.ensure_type(deps_g_type).first); level eq_rec_lvl1 = sort_level(m_tc.ensure_type(deps_g_type).first);
level eq_rec_lvl2 = sort_level(m_tc.ensure_type(A).first); level eq_rec_lvl2 = sort_level(m_tc.ensure_type(A).first);
expr tformer; expr tformer;
@ -615,19 +629,23 @@ class inversion_tac {
buffer<expr> new_hyps; buffer<expr> new_hyps;
new_hyps.append(non_deps); new_hyps.append(non_deps);
expr new_type = instantiate(abstract_local(deps_g_type, rhs), lhs); expr new_type = instantiate(abstract_local(deps_g_type, rhs), lhs);
abstract_local(imps, rhs); abstract_local(m_imps, rhs);
instantiate(imps, lhs); instantiate(m_imps, lhs);
if (!m_proof_irrel) { if (!m_proof_irrel) {
new_type = abstract_local(new_type, eq); new_type = abstract_local(new_type, eq);
new_type = instantiate(new_type, mk_refl(m_tc, lhs)); new_type = instantiate(new_type, mk_refl(m_tc, lhs));
} }
buffer<expr> new_deps;
for (unsigned i = 0; i < deps.size(); i++) { for (unsigned i = 0; i < deps.size(); i++) {
expr new_hyp = mk_local(m_ngen.next(), binding_name(new_type), binding_domain(new_type), expr new_hyp = mk_local(m_ngen.next(), binding_name(new_type), binding_domain(new_type),
binding_info(new_type)); binding_info(new_type));
new_hyps.push_back(new_hyp); new_hyps.push_back(new_hyp);
new_deps.push_back(new_hyp);
new_type = instantiate(binding_body(new_type), new_hyp); new_type = instantiate(binding_body(new_type), new_hyp);
instantiate(imps, new_hyp); instantiate(m_imps, new_hyp);
} }
lean_assert(deps.size() == new_deps.size());
store_renames(deps, new_deps);
expr new_mvar = mk_metavar(m_ngen.next(), Pi(new_hyps, new_type)); expr new_mvar = mk_metavar(m_ngen.next(), Pi(new_hyps, new_type));
expr new_meta = mk_app(new_mvar, new_hyps); expr new_meta = mk_app(new_mvar, new_hyps);
goal new_g(new_meta, new_type); goal new_g(new_meta, new_type);
@ -635,7 +653,7 @@ class inversion_tac {
eq_rec = mk_app(eq_rec, eq_rec_minor, rhs, eq); eq_rec = mk_app(eq_rec, eq_rec_minor, rhs, eq);
expr val = g.abstract(mk_app(eq_rec, deps)); expr val = g.abstract(mk_app(eq_rec, deps));
assign(g.get_name(), val); assign(g.get_name(), val);
return unify_eqs(new_g, imps, neqs-1); return unify_eqs(new_g, neqs-1);
} else if (is_local(lhs)) { } else if (is_local(lhs)) {
// flip equation and reduce to previous case // flip equation and reduce to previous case
if (m_proof_irrel) if (m_proof_irrel)
@ -650,28 +668,33 @@ class inversion_tac {
symm_pr = mk_app(symm_pr, A, lhs, rhs, eq); symm_pr = mk_app(symm_pr, A, lhs, rhs, eq);
expr val = g.abstract(mk_app(new_meta, symm_pr)); expr val = g.abstract(mk_app(new_meta, symm_pr));
assign(g.get_name(), val); assign(g.get_name(), val);
return unify_eqs(new_g, imps, neqs); return unify_eqs(new_g, neqs);
} }
throw inversion_exception("unification failed"); throw inversion_exception("unification failed");
} }
auto unify_eqs(list<goal> const & gs, list<unsigned> nargs, list<implementation_list> imps) -> auto unify_eqs(list<goal> const & gs, list<list<name>> args, list<implementation_list> imps) ->
std::tuple<list<goal>, list<unsigned>, list<implementation_list>> { std::tuple<list<goal>, list<list<name>>, list<implementation_list>, list<rename_map>> {
lean_assert(length(gs) == length(imps)); lean_assert(length(gs) == length(imps));
unsigned neqs = m_nindices + (m_dep_elim ? 1 : 0); unsigned neqs = m_nindices + (m_dep_elim ? 1 : 0);
buffer<goal> new_goals; buffer<goal> new_goals;
buffer<unsigned> new_nargs; buffer<list<name>> new_args;
buffer<implementation_list> new_imps; buffer<implementation_list> new_imps;
buffer<rename_map> new_renames;
for (goal const & g : gs) { for (goal const & g : gs) {
if (auto g_imp_pair = unify_eqs(g, head(imps), neqs)) { flet<rename_map> set1(m_renames, rename_map());
new_goals.push_back(g_imp_pair->first); flet<implementation_list> set2(m_imps, head(imps));
new_nargs.push_back(head(nargs)); if (optional<goal> new_g = unify_eqs(g, neqs)) {
new_imps.push_back(g_imp_pair->second); new_goals.push_back(*new_g);
list<name> new_as = map(head(args), [&](name const & n) { return m_renames.find(n); });
new_args.push_back(new_as);
new_imps.push_back(m_imps);
new_renames.push_back(m_renames);
} }
imps = tail(imps); imps = tail(imps);
nargs = tail(nargs); args = tail(args);
} }
return std::make_tuple(to_list(new_goals), to_list(new_nargs), to_list(new_imps)); return std::make_tuple(to_list(new_goals), to_list(new_args), to_list(new_imps), to_list(new_renames));
} }
public: public:
@ -685,15 +708,10 @@ public:
inversion_tac(environment const & env, io_state const & ios, type_checker & tc): inversion_tac(environment const & env, io_state const & ios, type_checker & tc):
inversion_tac(env, ios, tc.mk_ngen(), tc, substitution(), list<name>()) {} inversion_tac(env, ios, tc.mk_ngen(), tc, substitution(), list<name>()) {}
typedef inversion::result result; typedef inversion::result result;
optional<result> execute(goal const & g, name const & n, implementation_list const & imps) { optional<result> execute(goal const & g, expr const & h, implementation_list const & imps) {
try { try {
auto p = g.find_hyp(n);
if (!p)
return optional<result>();
expr const & h = p->first;
expr h_type = m_tc.whnf(mlocal_type(h)).first; expr h_type = m_tc.whnf(mlocal_type(h)).first;
if (!is_inversion_applicable(h_type)) if (!is_inversion_applicable(h_type))
return optional<result>(); return optional<result>();
@ -701,22 +719,31 @@ public:
auto gs_imps_pair = apply_cases_on(g1, imps); auto gs_imps_pair = apply_cases_on(g1, imps);
list<goal> gs2 = gs_imps_pair.first; list<goal> gs2 = gs_imps_pair.first;
list<implementation_list> new_imps = gs_imps_pair.second; list<implementation_list> new_imps = gs_imps_pair.second;
auto gs_nargs_pair = intros_minors_args(gs2); auto gs_args_pair = intros_minors_args(gs2);
list<goal> gs3 = gs_nargs_pair.first; list<goal> gs3 = gs_args_pair.first;
list<unsigned> nargs = gs_nargs_pair.second; list<list<name>> args = gs_args_pair.second;
list<goal> gs4; list<goal> gs4;
std::tie(gs4, nargs, new_imps) = unify_eqs(gs3, nargs, new_imps); list<rename_map> rs;
return optional<result>(result(gs4, nargs, new_imps, m_ngen, m_subst)); std::tie(gs4, args, new_imps, rs) = unify_eqs(gs3, args, new_imps);
return optional<result>(result(gs4, args, new_imps, rs, m_ngen, m_subst));
} catch (inversion_exception & ex) { } catch (inversion_exception & ex) {
return optional<result>(); return optional<result>();
} }
} }
optional<result> execute(goal const & g, name const & n, implementation_list const & imps) {
auto p = g.find_hyp(n);
if (!p)
return optional<result>();
expr const & h = p->first;
return execute(g, h, imps);
}
}; };
namespace inversion { namespace inversion {
optional<result> apply(environment const & env, io_state const & ios, type_checker & tc, optional<result> apply(environment const & env, io_state const & ios, type_checker & tc,
goal const & g, name const & n, implementation_list const & imps) { goal const & g, expr const & h, implementation_list const & imps) {
return inversion_tac(env, ios, tc).execute(g, n, imps); return inversion_tac(env, ios, tc).execute(g, h, imps);
} }
} }

View file

@ -7,6 +7,7 @@ Author: Leonardo de Moura
#pragma once #pragma once
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "util/name_map.h"
#include "library/tactic/tactic.h" #include "library/tactic/tactic.h"
namespace lean { namespace lean {
@ -43,19 +44,22 @@ typedef list<implementation_ptr> implementation_list;
struct result { struct result {
list<goal> m_goals; list<goal> m_goals;
list<unsigned> m_num_args; list<list<name>> m_args; // arguments of the constructor/intro rule
list<implementation_list> m_implementation_lists; list<implementation_list> m_implementation_lists;
// invariant: length(m_goals) == length(m_num_args); list<rename_map> m_renames;
// invariant: length(m_goals) == length(m_args);
// invariant: length(m_goals) == length(m_implementation_lists); // invariant: length(m_goals) == length(m_implementation_lists);
// invariant: length(m_goals) == length(m_renames);
name_generator m_ngen; name_generator m_ngen;
substitution m_subst; substitution m_subst;
result(list<goal> const & gs, list<unsigned> const & num_args, list<implementation_list> const & imps, result(list<goal> const & gs, list<list<name>> const & args, list<implementation_list> const & imps,
name_generator const & ngen, substitution const & subst); list<rename_map> const & rs, name_generator const & ngen, substitution const & subst);
}; };
optional<result> apply(environment const & env, io_state const & ios, type_checker & tc, optional<result> apply(environment const & env, io_state const & ios, type_checker & tc,
goal const & g, name const & n, implementation_list const & imps); goal const & g, expr const & h, implementation_list const & imps);
} }
tactic inversion_tactic(name const & n, list<name> const & ids = list<name>()); tactic inversion_tactic(name const & n, list<name> const & ids = list<name>());
void initialize_inversion_tactic(); void initialize_inversion_tactic();
void finalize_inversion_tactic(); void finalize_inversion_tactic();