refactor(kernel,library,frontends/lean): add helper functions, and cleanup collect_locals

This commit is contained in:
Leonardo de Moura 2015-06-05 17:29:36 -07:00
parent 7db84c7036
commit d0d3f9bb41
12 changed files with 160 additions and 127 deletions

View file

@ -342,7 +342,7 @@ struct inductive_cmd_fn {
}
/** \brief Collect local constants used in the inductive decls. */
void collect_locals_core(buffer<inductive_decl> const & decls, expr_struct_set & ls) {
void collect_locals_core(buffer<inductive_decl> const & decls, collected_locals & ls) {
buffer<expr> include_vars;
m_p.get_include_variables(include_vars);
for (expr const & param : include_vars) {
@ -365,11 +365,11 @@ struct inductive_cmd_fn {
void collect_locals(buffer<inductive_decl> & decls, buffer<expr> & locals) {
if (!m_p.has_locals())
return;
expr_struct_set local_set;
collected_locals local_set;
collect_locals_core(decls, local_set);
if (local_set.empty())
return;
sort_locals(local_set, m_p, locals);
sort_locals(local_set.get_collected(), m_p, locals);
m_num_params += locals.size();
}

View file

@ -357,7 +357,7 @@ struct migrate_cmd_fn {
for (auto const & p : m_replacements)
tmp_locals.push_back(mk_local(m_ngen.next(), mk_as_type(p.second)));
expr_struct_set dep_set;
collected_locals dep_set;
for (expr const & v : include_vars) {
::lean::collect_locals(mlocal_type(v), dep_set);
dep_set.insert(v);
@ -365,7 +365,7 @@ struct migrate_cmd_fn {
for (expr const & p : m_params)
::lean::collect_locals(mlocal_type(p), dep_set);
buffer<expr> ctx;
sort_locals(dep_set, m_p, ctx);
sort_locals(dep_set.get_collected(), m_p, ctx);
expr dummy = mk_Prop();
expr tmp = Pi_as_is(ctx, Pi(tmp_locals, dummy, m_p), m_p);

View file

@ -299,7 +299,7 @@ struct structure_cmd_fn {
for (expr const & parent : m_parents)
tmp_locals.push_back(mk_local(m_ngen.next(), parent));
expr_struct_set dep_set;
collected_locals dep_set;
for (expr const & v : include_vars) {
::lean::collect_locals(mlocal_type(v), dep_set);
dep_set.insert(v);
@ -307,7 +307,7 @@ struct structure_cmd_fn {
for (expr const & p : m_params)
::lean::collect_locals(mlocal_type(p), dep_set);
buffer<expr> ctx;
sort_locals(dep_set, m_p, ctx);
sort_locals(dep_set.get_collected(), m_p, ctx);
expr tmp = Pi_as_is(ctx, Pi(tmp_locals, m_type, m_p), m_p);
level_param_names new_ls;
@ -602,9 +602,9 @@ struct structure_cmd_fn {
return;
expr dummy = mk_Prop();
expr tmp = Pi(m_params, Pi(m_fields, dummy));
expr_struct_set local_set;
collected_locals local_set;
::lean::collect_locals(tmp, local_set);
sort_locals(local_set, m_p, locals);
sort_locals(local_set.get_collected(), m_p, locals);
}
/** \brief Add context locals as extra parameters */

View file

@ -74,7 +74,7 @@ name remove_root_prefix(name const & n) {
}
// Sort local names by order of occurrence, and copy the associated parameters to ps
void sort_locals(expr_struct_set const & locals, parser const & p, buffer<expr> & ps) {
void sort_locals(buffer<expr> const & locals, parser const & p, buffer<expr> & ps) {
for (expr const & l : locals) {
// we only copy the locals that are in p's local context
if (p.is_local_decl(l))
@ -104,9 +104,9 @@ levels collect_local_nonvar_levels(parser & p, level_param_names const & ls) {
return to_list(section_ls_buffer.begin(), section_ls_buffer.end());
}
// Version of collect_locals(expr const & e, expr_struct_set & ls) that ignores local constants occurring in
// Version of collect_locals(expr const & e, collected_locals & ls) that ignores local constants occurring in
// tactics.
static void collect_locals_ignoring_tactics(expr const & e, expr_struct_set & ls) {
static void collect_locals_ignoring_tactics(expr const & e, collected_locals & ls) {
if (!has_local(e))
return;
for_each(e, [&](expr const & e, unsigned) {
@ -122,7 +122,7 @@ static void collect_locals_ignoring_tactics(expr const & e, expr_struct_set & ls
// Collect local constants occurring in type and value, sort them, and store in ctx_ps
void collect_locals(expr const & type, expr const & value, parser const & p, buffer<expr> & ctx_ps) {
expr_struct_set ls;
collected_locals ls;
buffer<expr> include_vars;
p.get_include_variables(include_vars);
for (expr const & param : include_vars) {
@ -133,7 +133,7 @@ void collect_locals(expr const & type, expr const & value, parser const & p, buf
}
collect_locals_ignoring_tactics(type, ls);
collect_locals_ignoring_tactics(value, ls);
sort_locals(ls, p, ctx_ps);
sort_locals(ls.get_collected(), p, ctx_ps);
}
name_set collect_univ_params_ignoring_tactics(expr const & e, name_set const & ls) {
@ -177,10 +177,10 @@ levels remove_local_vars(parser const & p, levels const & ls) {
}
list<expr> locals_to_context(expr const & e, parser const & p) {
expr_struct_set ls;
collected_locals ls;
collect_locals_ignoring_tactics(e, ls);
buffer<expr> locals;
sort_locals(ls, p, locals);
sort_locals(ls.get_collected(), p, locals);
std::reverse(locals.begin(), locals.end());
return to_list(locals.begin(), locals.end());
}

View file

@ -38,7 +38,7 @@ levels collect_local_nonvar_levels(parser & p, level_param_names const & ls);
void collect_locals(expr const & type, expr const & value, parser const & p, buffer<expr> & ctx_ps);
name_set collect_univ_params_ignoring_tactics(expr const & e, name_set const & ls = name_set());
/** \brief Copy the local names to \c ps, then sort \c ps (using the order in which they were declared). */
void sort_locals(expr_struct_set const & locals, parser const & p, buffer<expr> & ps);
void sort_locals(buffer<expr> const & locals, parser const & p, buffer<expr> & ps);
/** \brief Remove from \c ps local constants that are tagged as variables. */
void remove_local_vars(parser const & p, buffer<expr> & ps);
/** \brief Remove from \c ls any universe level that is tagged as variable in \c p */

View file

@ -391,13 +391,17 @@ pair<bool, constraint_seq> type_checker::is_def_eq_types(expr const & t, expr co
/** \brief Return true iff \c e is a proposition */
pair<bool, constraint_seq> type_checker::is_prop(expr const & e) {
auto tcs = infer_type(e);
auto wtcs = whnf(tcs.first);
bool r = wtcs.first == mk_Prop();
if (r)
return mk_pair(true, tcs.second + wtcs.second);
else
if (m_env.impredicative()) {
auto tcs = infer_type(e);
auto wtcs = whnf(tcs.first);
bool r = wtcs.first == mk_Prop();
if (r)
return mk_pair(true, tcs.second + wtcs.second);
else
return mk_pair(false, constraint_seq());
} else {
return mk_pair(false, constraint_seq());
}
}
pair<expr, constraint_seq> type_checker::whnf(expr const & t) {

View file

@ -205,6 +205,12 @@ public:
cs = r.second + cs;
return r.first;
}
bool is_prop(expr const & t, constraint_seq & cs) {
auto r = is_prop(t);
if (r.first)
cs += r.second;
return r.first;
}
optional<expr> expand_macro(expr const & m);

View file

@ -49,7 +49,14 @@ level_param_names to_level_param_names(name_set const & ls) {
return r;
}
void collect_locals(expr const & e, expr_struct_set & ls, bool restricted) {
void collected_locals::insert(expr const & l) {
if (m_local_names.contains(mlocal_name(l)))
return;
m_local_names.insert(mlocal_name(l));
m_locals.push_back(l);
}
void collect_locals(expr const & e, collected_locals & ls, bool restricted) {
if (!has_local(e))
return;
for_each(e, [&](expr const & e, unsigned) {

View file

@ -15,7 +15,17 @@ name_set collect_univ_params(expr const & e, name_set const & ls = name_set());
\remark If restricted is true, then locals in meta-variable applications and local constants
are ignored.
*/
void collect_locals(expr const & e, expr_struct_set & ls, bool restricted = false);
class collected_locals {
name_set m_local_names;
buffer<expr> m_locals;
public:
void insert(expr const & l);
bool contains(expr const & l) const { return m_local_names.contains(mlocal_name(l)); }
buffer<expr> const & get_collected() const { return m_locals; }
bool empty() const { return m_locals.empty(); }
};
void collect_locals(expr const & e, collected_locals & ls, bool restricted = false);
level_param_names to_level_param_names(name_set const & ls);
/** \brief Return true iff \c [begin_locals, end_locals) contains \c local */

View file

@ -1254,7 +1254,7 @@ class rewrite_fn {
return result;
}
bool move_after(expr const & hyp, expr_struct_set const & hyps) {
bool move_after(expr const & hyp, buffer<expr> const & hyps) {
buffer<name> used_hyp_names;
for (auto const & p : hyps) {
used_hyp_names.push_back(mlocal_name(p));
@ -1279,9 +1279,9 @@ class rewrite_fn {
expr a, Heq, b; // Heq is a proof of a = b
std::tie(a, b, Heq) = *it;
// We must make sure that hyp occurs after all hypotheses in b
expr_struct_set b_hyps;
collected_locals b_hyps;
collect_locals(b, b_hyps);
if (!move_after(hyp, b_hyps))
if (!move_after(hyp, b_hyps.get_collected()))
return false;
bool has_dep_elim = inductive::has_dep_elim(m_env, get_eq_name());
unsigned vidx = has_dep_elim ? 1 : 0;
@ -1368,7 +1368,7 @@ class rewrite_fn {
location const & loc = info.get_location();
if (loc.is_goal_only())
return process_rewrite_goal(orig_elem, pattern, *loc.includes_goal());
expr_struct_set used_hyps;
collected_locals used_hyps;
collect_locals(elem, used_hyps, true);
// We collect hypotheses used in the rewrite step. They are not rewritten.
// That is, we don't use them to rewrite themselves.
@ -1377,7 +1377,7 @@ class rewrite_fn {
buffer<expr> hyps;
m_g.get_hyps(hyps);
for (expr const & h : hyps) {
if (used_hyps.find(h) != used_hyps.end())
if (used_hyps.contains(h))
continue; // skip used hypothesis
auto occ = loc.includes_hypothesis(local_pp_name(h));
if (!occ)

View file

@ -34,105 +34,109 @@ static void split_deps(buffer<expr> const & hyps, expr const & x, expr const & h
}
}
optional<proof_state> subst(environment const & env, name const & h_name, bool symm, proof_state const & s) {
goals const & gs = s.get_goals();
if (empty(gs))
return none_proof_state();
goal g = head(gs);
auto opt_h = g.find_hyp_from_internal_name(h_name);
if (!opt_h)
return none_proof_state();
expr const & h = opt_h->first;
expr lhs, rhs;
if (!is_eq(mlocal_type(h), lhs, rhs))
return none_proof_state();
name_generator ngen = s.get_ngen();
auto tc = mk_type_checker(env, ngen.mk_child());
if (symm)
std::swap(lhs, rhs);
if (!is_local(lhs))
return none_proof_state();
buffer<expr> hyps, deps, non_deps;
g.get_hyps(hyps);
bool depends_on_h = depends_on(g.get_type(), h);
split_deps(hyps, lhs, h, non_deps, deps, depends_on_h);
// revert dependencies
expr type = Pi(deps, g.get_type());
// substitute
bool has_dep_elim = inductive::has_dep_elim(env, get_eq_name());
bool use_dep_elim = has_dep_elim;
if (depends_on_h)
use_dep_elim = true;
expr motive, new_type;
new_type = instantiate(abstract_local(type, mlocal_name(lhs)), rhs);
if (use_dep_elim) {
new_type = instantiate(abstract_local(new_type, mlocal_name(h)), mk_refl(*tc, rhs));
if (symm) {
motive = Fun(lhs, Fun(h, type));
} else {
expr Heq = mk_local(ngen.next(), local_pp_name(h), mk_eq(*tc, rhs, lhs), binder_info());
motive = Fun(lhs, Fun(Heq, type));
}
} else {
motive = Fun(lhs, type);
}
buffer<expr> new_hyps;
buffer<expr> intros_hyps;
new_hyps.append(non_deps);
// reintroduce dependencies
expr new_goal_type = new_type;
for (expr const & d : deps) {
if (!is_pi(new_goal_type))
return none_proof_state();
expr new_h = mk_local(ngen.next(), local_pp_name(d), binding_domain(new_goal_type), binder_info());
new_hyps.push_back(new_h);
intros_hyps.push_back(new_h);
new_goal_type = instantiate(binding_body(new_goal_type), new_h);
}
// create new goal
expr new_metavar = mk_metavar(ngen.next(), Pi(new_hyps, new_goal_type));
expr new_meta_core = mk_app(new_metavar, non_deps);
expr new_meta = mk_app(new_meta_core, intros_hyps);
goal new_g(new_meta, new_goal_type);
// create eqrec term
substitution new_subst = s.get_subst();
expr major = symm ? h : mk_symm(*tc, h);
expr minor = new_meta_core;
expr A = tc->infer(lhs).first;
level l1 = sort_level(tc->ensure_type(new_type).first);
level l2 = sort_level(tc->ensure_type(A).first);
name eq_rec_name;
if (!has_dep_elim && use_dep_elim)
eq_rec_name = get_eq_drec_name();
else
eq_rec_name = get_eq_rec_name();
expr eqrec = mk_app({mk_constant(eq_rec_name, {l1, l2}), A, rhs, motive, minor, lhs, major});
if (use_dep_elim) {
try {
check_term(env, g.abstract(eqrec));
} catch (kernel_exception & ex) {
if (!s.report_failure())
return none_proof_state();
std::shared_ptr<kernel_exception> saved_ex(static_cast<kernel_exception*>(ex.clone()));
throw tactic_exception("rewrite step failed", none_expr(), s,
[=](formatter const & fmt) {
format r;
r += format("invalid 'subst' tactic, "
"produced type incorrect term, details: ");
r += saved_ex->pp(fmt);
r += line();
return r;
});
}
}
expr new_val = mk_app(eqrec, deps);
assign(new_subst, g, new_val);
lean_assert(new_subst.is_assigned(g.get_mvar()));
proof_state new_s(s, goals(new_g, tail(gs)), new_subst, ngen);
return some_proof_state(new_s);
}
tactic mk_subst_tactic_core(name const & h_name, bool symm) {
auto fn = [=](environment const & env, io_state const &, proof_state const & s) {
goals const & gs = s.get_goals();
if (empty(gs))
return none_proof_state();
goal g = head(gs);
auto opt_h = g.find_hyp_from_internal_name(h_name);
if (!opt_h)
return none_proof_state();
expr const & h = opt_h->first;
expr lhs, rhs;
if (!is_eq(mlocal_type(h), lhs, rhs))
return none_proof_state();
name_generator ngen = s.get_ngen();
auto tc = mk_type_checker(env, ngen.mk_child());
if (symm)
std::swap(lhs, rhs);
if (!is_local(lhs))
return none_proof_state();
buffer<expr> hyps, deps, non_deps;
g.get_hyps(hyps);
bool depends_on_h = depends_on(g.get_type(), h);
split_deps(hyps, lhs, h, non_deps, deps, depends_on_h);
// revert dependencies
expr type = Pi(deps, g.get_type());
// substitute
bool has_dep_elim = inductive::has_dep_elim(env, get_eq_name());
bool use_dep_elim = has_dep_elim;
if (depends_on_h)
use_dep_elim = true;
expr motive, new_type;
new_type = instantiate(abstract_local(type, mlocal_name(lhs)), rhs);
if (use_dep_elim) {
new_type = instantiate(abstract_local(new_type, mlocal_name(h)), mk_refl(*tc, rhs));
if (symm) {
motive = Fun(lhs, Fun(h, type));
} else {
expr Heq = mk_local(ngen.next(), local_pp_name(h), mk_eq(*tc, rhs, lhs), binder_info());
motive = Fun(lhs, Fun(Heq, type));
}
} else {
motive = Fun(lhs, type);
}
buffer<expr> new_hyps;
buffer<expr> intros_hyps;
new_hyps.append(non_deps);
// reintroduce dependencies
expr new_goal_type = new_type;
for (expr const & d : deps) {
if (!is_pi(new_goal_type))
return none_proof_state();
expr new_h = mk_local(ngen.next(), local_pp_name(d), binding_domain(new_goal_type), binder_info());
new_hyps.push_back(new_h);
intros_hyps.push_back(new_h);
new_goal_type = instantiate(binding_body(new_goal_type), new_h);
}
// create new goal
expr new_metavar = mk_metavar(ngen.next(), Pi(new_hyps, new_goal_type));
expr new_meta_core = mk_app(new_metavar, non_deps);
expr new_meta = mk_app(new_meta_core, intros_hyps);
goal new_g(new_meta, new_goal_type);
// create eqrec term
substitution new_subst = s.get_subst();
expr major = symm ? h : mk_symm(*tc, h);
expr minor = new_meta_core;
expr A = tc->infer(lhs).first;
level l1 = sort_level(tc->ensure_type(new_type).first);
level l2 = sort_level(tc->ensure_type(A).first);
name eq_rec_name;
if (!has_dep_elim && use_dep_elim)
eq_rec_name = get_eq_drec_name();
else
eq_rec_name = get_eq_rec_name();
expr eqrec = mk_app({mk_constant(eq_rec_name, {l1, l2}), A, rhs, motive, minor, lhs, major});
if (use_dep_elim) {
try {
check_term(env, g.abstract(eqrec));
} catch (kernel_exception & ex) {
if (!s.report_failure())
return none_proof_state();
std::shared_ptr<kernel_exception> saved_ex(static_cast<kernel_exception*>(ex.clone()));
throw tactic_exception("rewrite step failed", none_expr(), s,
[=](formatter const & fmt) {
format r;
r += format("invalid 'subst' tactic, "
"produced type incorrect term, details: ");
r += saved_ex->pp(fmt);
r += line();
return r;
});
}
}
expr new_val = mk_app(eqrec, deps);
assign(new_subst, g, new_val);
lean_assert(new_subst.is_assigned(g.get_mvar()));
proof_state new_s(s, goals(new_g, tail(gs)), new_subst, ngen);
return some_proof_state(new_s);
return subst(env, h_name, symm, s);
};
return tactic01(fn);
}

View file

@ -5,7 +5,9 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "library/tactic/proof_state.h"
namespace lean {
optional<proof_state> subst(environment const & env, name const & h_name, bool symm, proof_state const & s);
void initialize_subst_tactic();
void finalize_subst_tactic();
}