From d0d3f9bb41de18f826b107e043acbf9df952e69c Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 5 Jun 2015 17:29:36 -0700 Subject: [PATCH] refactor(kernel,library,frontends/lean): add helper functions, and cleanup collect_locals --- src/frontends/lean/inductive_cmd.cpp | 6 +- src/frontends/lean/migrate_cmd.cpp | 4 +- src/frontends/lean/structure_cmd.cpp | 8 +- src/frontends/lean/util.cpp | 14 +- src/frontends/lean/util.h | 2 +- src/kernel/type_checker.cpp | 16 ++- src/kernel/type_checker.h | 6 + src/library/locals.cpp | 9 +- src/library/locals.h | 12 +- src/library/tactic/rewrite_tactic.cpp | 10 +- src/library/tactic/subst_tactic.cpp | 198 +++++++++++++------------- src/library/tactic/subst_tactic.h | 2 + 12 files changed, 160 insertions(+), 127 deletions(-) diff --git a/src/frontends/lean/inductive_cmd.cpp b/src/frontends/lean/inductive_cmd.cpp index ade6bb1b6..b1ff6bb77 100644 --- a/src/frontends/lean/inductive_cmd.cpp +++ b/src/frontends/lean/inductive_cmd.cpp @@ -342,7 +342,7 @@ struct inductive_cmd_fn { } /** \brief Collect local constants used in the inductive decls. */ - void collect_locals_core(buffer const & decls, expr_struct_set & ls) { + void collect_locals_core(buffer const & decls, collected_locals & ls) { buffer include_vars; m_p.get_include_variables(include_vars); for (expr const & param : include_vars) { @@ -365,11 +365,11 @@ struct inductive_cmd_fn { void collect_locals(buffer & decls, buffer & locals) { if (!m_p.has_locals()) return; - expr_struct_set local_set; + collected_locals local_set; collect_locals_core(decls, local_set); if (local_set.empty()) return; - sort_locals(local_set, m_p, locals); + sort_locals(local_set.get_collected(), m_p, locals); m_num_params += locals.size(); } diff --git a/src/frontends/lean/migrate_cmd.cpp b/src/frontends/lean/migrate_cmd.cpp index 53b8fd5e3..e8b148eab 100644 --- a/src/frontends/lean/migrate_cmd.cpp +++ b/src/frontends/lean/migrate_cmd.cpp @@ -357,7 +357,7 @@ struct migrate_cmd_fn { for (auto const & p : m_replacements) tmp_locals.push_back(mk_local(m_ngen.next(), mk_as_type(p.second))); - expr_struct_set dep_set; + collected_locals dep_set; for (expr const & v : include_vars) { ::lean::collect_locals(mlocal_type(v), dep_set); dep_set.insert(v); @@ -365,7 +365,7 @@ struct migrate_cmd_fn { for (expr const & p : m_params) ::lean::collect_locals(mlocal_type(p), dep_set); buffer ctx; - sort_locals(dep_set, m_p, ctx); + sort_locals(dep_set.get_collected(), m_p, ctx); expr dummy = mk_Prop(); expr tmp = Pi_as_is(ctx, Pi(tmp_locals, dummy, m_p), m_p); diff --git a/src/frontends/lean/structure_cmd.cpp b/src/frontends/lean/structure_cmd.cpp index 387aba92d..92f73ecd3 100644 --- a/src/frontends/lean/structure_cmd.cpp +++ b/src/frontends/lean/structure_cmd.cpp @@ -299,7 +299,7 @@ struct structure_cmd_fn { for (expr const & parent : m_parents) tmp_locals.push_back(mk_local(m_ngen.next(), parent)); - expr_struct_set dep_set; + collected_locals dep_set; for (expr const & v : include_vars) { ::lean::collect_locals(mlocal_type(v), dep_set); dep_set.insert(v); @@ -307,7 +307,7 @@ struct structure_cmd_fn { for (expr const & p : m_params) ::lean::collect_locals(mlocal_type(p), dep_set); buffer ctx; - sort_locals(dep_set, m_p, ctx); + sort_locals(dep_set.get_collected(), m_p, ctx); expr tmp = Pi_as_is(ctx, Pi(tmp_locals, m_type, m_p), m_p); level_param_names new_ls; @@ -602,9 +602,9 @@ struct structure_cmd_fn { return; expr dummy = mk_Prop(); expr tmp = Pi(m_params, Pi(m_fields, dummy)); - expr_struct_set local_set; + collected_locals local_set; ::lean::collect_locals(tmp, local_set); - sort_locals(local_set, m_p, locals); + sort_locals(local_set.get_collected(), m_p, locals); } /** \brief Add context locals as extra parameters */ diff --git a/src/frontends/lean/util.cpp b/src/frontends/lean/util.cpp index 31c888bfd..303733b16 100644 --- a/src/frontends/lean/util.cpp +++ b/src/frontends/lean/util.cpp @@ -74,7 +74,7 @@ name remove_root_prefix(name const & n) { } // Sort local names by order of occurrence, and copy the associated parameters to ps -void sort_locals(expr_struct_set const & locals, parser const & p, buffer & ps) { +void sort_locals(buffer const & locals, parser const & p, buffer & ps) { for (expr const & l : locals) { // we only copy the locals that are in p's local context if (p.is_local_decl(l)) @@ -104,9 +104,9 @@ levels collect_local_nonvar_levels(parser & p, level_param_names const & ls) { return to_list(section_ls_buffer.begin(), section_ls_buffer.end()); } -// Version of collect_locals(expr const & e, expr_struct_set & ls) that ignores local constants occurring in +// Version of collect_locals(expr const & e, collected_locals & ls) that ignores local constants occurring in // tactics. -static void collect_locals_ignoring_tactics(expr const & e, expr_struct_set & ls) { +static void collect_locals_ignoring_tactics(expr const & e, collected_locals & ls) { if (!has_local(e)) return; for_each(e, [&](expr const & e, unsigned) { @@ -122,7 +122,7 @@ static void collect_locals_ignoring_tactics(expr const & e, expr_struct_set & ls // Collect local constants occurring in type and value, sort them, and store in ctx_ps void collect_locals(expr const & type, expr const & value, parser const & p, buffer & ctx_ps) { - expr_struct_set ls; + collected_locals ls; buffer include_vars; p.get_include_variables(include_vars); for (expr const & param : include_vars) { @@ -133,7 +133,7 @@ void collect_locals(expr const & type, expr const & value, parser const & p, buf } collect_locals_ignoring_tactics(type, ls); collect_locals_ignoring_tactics(value, ls); - sort_locals(ls, p, ctx_ps); + sort_locals(ls.get_collected(), p, ctx_ps); } name_set collect_univ_params_ignoring_tactics(expr const & e, name_set const & ls) { @@ -177,10 +177,10 @@ levels remove_local_vars(parser const & p, levels const & ls) { } list locals_to_context(expr const & e, parser const & p) { - expr_struct_set ls; + collected_locals ls; collect_locals_ignoring_tactics(e, ls); buffer locals; - sort_locals(ls, p, locals); + sort_locals(ls.get_collected(), p, locals); std::reverse(locals.begin(), locals.end()); return to_list(locals.begin(), locals.end()); } diff --git a/src/frontends/lean/util.h b/src/frontends/lean/util.h index 41a3e54ac..f87078a93 100644 --- a/src/frontends/lean/util.h +++ b/src/frontends/lean/util.h @@ -38,7 +38,7 @@ levels collect_local_nonvar_levels(parser & p, level_param_names const & ls); void collect_locals(expr const & type, expr const & value, parser const & p, buffer & ctx_ps); name_set collect_univ_params_ignoring_tactics(expr const & e, name_set const & ls = name_set()); /** \brief Copy the local names to \c ps, then sort \c ps (using the order in which they were declared). */ -void sort_locals(expr_struct_set const & locals, parser const & p, buffer & ps); +void sort_locals(buffer const & locals, parser const & p, buffer & ps); /** \brief Remove from \c ps local constants that are tagged as variables. */ void remove_local_vars(parser const & p, buffer & ps); /** \brief Remove from \c ls any universe level that is tagged as variable in \c p */ diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index e60ceb2d6..6482515b1 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -391,13 +391,17 @@ pair type_checker::is_def_eq_types(expr const & t, expr co /** \brief Return true iff \c e is a proposition */ pair type_checker::is_prop(expr const & e) { - auto tcs = infer_type(e); - auto wtcs = whnf(tcs.first); - bool r = wtcs.first == mk_Prop(); - if (r) - return mk_pair(true, tcs.second + wtcs.second); - else + if (m_env.impredicative()) { + auto tcs = infer_type(e); + auto wtcs = whnf(tcs.first); + bool r = wtcs.first == mk_Prop(); + if (r) + return mk_pair(true, tcs.second + wtcs.second); + else + return mk_pair(false, constraint_seq()); + } else { return mk_pair(false, constraint_seq()); + } } pair type_checker::whnf(expr const & t) { diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index d183a84eb..a8c09bb78 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -205,6 +205,12 @@ public: cs = r.second + cs; return r.first; } + bool is_prop(expr const & t, constraint_seq & cs) { + auto r = is_prop(t); + if (r.first) + cs += r.second; + return r.first; + } optional expand_macro(expr const & m); diff --git a/src/library/locals.cpp b/src/library/locals.cpp index 1780bb4c4..e390aeaae 100644 --- a/src/library/locals.cpp +++ b/src/library/locals.cpp @@ -49,7 +49,14 @@ level_param_names to_level_param_names(name_set const & ls) { return r; } -void collect_locals(expr const & e, expr_struct_set & ls, bool restricted) { +void collected_locals::insert(expr const & l) { + if (m_local_names.contains(mlocal_name(l))) + return; + m_local_names.insert(mlocal_name(l)); + m_locals.push_back(l); +} + +void collect_locals(expr const & e, collected_locals & ls, bool restricted) { if (!has_local(e)) return; for_each(e, [&](expr const & e, unsigned) { diff --git a/src/library/locals.h b/src/library/locals.h index 56ad9e1a0..ad48847bd 100644 --- a/src/library/locals.h +++ b/src/library/locals.h @@ -15,7 +15,17 @@ name_set collect_univ_params(expr const & e, name_set const & ls = name_set()); \remark If restricted is true, then locals in meta-variable applications and local constants are ignored. */ -void collect_locals(expr const & e, expr_struct_set & ls, bool restricted = false); +class collected_locals { + name_set m_local_names; + buffer m_locals; +public: + void insert(expr const & l); + bool contains(expr const & l) const { return m_local_names.contains(mlocal_name(l)); } + buffer const & get_collected() const { return m_locals; } + bool empty() const { return m_locals.empty(); } +}; + +void collect_locals(expr const & e, collected_locals & ls, bool restricted = false); level_param_names to_level_param_names(name_set const & ls); /** \brief Return true iff \c [begin_locals, end_locals) contains \c local */ diff --git a/src/library/tactic/rewrite_tactic.cpp b/src/library/tactic/rewrite_tactic.cpp index c2076b0ea..d48edd45a 100644 --- a/src/library/tactic/rewrite_tactic.cpp +++ b/src/library/tactic/rewrite_tactic.cpp @@ -1254,7 +1254,7 @@ class rewrite_fn { return result; } - bool move_after(expr const & hyp, expr_struct_set const & hyps) { + bool move_after(expr const & hyp, buffer const & hyps) { buffer used_hyp_names; for (auto const & p : hyps) { used_hyp_names.push_back(mlocal_name(p)); @@ -1279,9 +1279,9 @@ class rewrite_fn { expr a, Heq, b; // Heq is a proof of a = b std::tie(a, b, Heq) = *it; // We must make sure that hyp occurs after all hypotheses in b - expr_struct_set b_hyps; + collected_locals b_hyps; collect_locals(b, b_hyps); - if (!move_after(hyp, b_hyps)) + if (!move_after(hyp, b_hyps.get_collected())) return false; bool has_dep_elim = inductive::has_dep_elim(m_env, get_eq_name()); unsigned vidx = has_dep_elim ? 1 : 0; @@ -1368,7 +1368,7 @@ class rewrite_fn { location const & loc = info.get_location(); if (loc.is_goal_only()) return process_rewrite_goal(orig_elem, pattern, *loc.includes_goal()); - expr_struct_set used_hyps; + collected_locals used_hyps; collect_locals(elem, used_hyps, true); // We collect hypotheses used in the rewrite step. They are not rewritten. // That is, we don't use them to rewrite themselves. @@ -1377,7 +1377,7 @@ class rewrite_fn { buffer hyps; m_g.get_hyps(hyps); for (expr const & h : hyps) { - if (used_hyps.find(h) != used_hyps.end()) + if (used_hyps.contains(h)) continue; // skip used hypothesis auto occ = loc.includes_hypothesis(local_pp_name(h)); if (!occ) diff --git a/src/library/tactic/subst_tactic.cpp b/src/library/tactic/subst_tactic.cpp index e5dcd25e7..b35b59e8d 100644 --- a/src/library/tactic/subst_tactic.cpp +++ b/src/library/tactic/subst_tactic.cpp @@ -34,105 +34,109 @@ static void split_deps(buffer const & hyps, expr const & x, expr const & h } } +optional subst(environment const & env, name const & h_name, bool symm, proof_state const & s) { + goals const & gs = s.get_goals(); + if (empty(gs)) + return none_proof_state(); + goal g = head(gs); + auto opt_h = g.find_hyp_from_internal_name(h_name); + if (!opt_h) + return none_proof_state(); + expr const & h = opt_h->first; + expr lhs, rhs; + if (!is_eq(mlocal_type(h), lhs, rhs)) + return none_proof_state(); + name_generator ngen = s.get_ngen(); + auto tc = mk_type_checker(env, ngen.mk_child()); + if (symm) + std::swap(lhs, rhs); + if (!is_local(lhs)) + return none_proof_state(); + buffer hyps, deps, non_deps; + g.get_hyps(hyps); + bool depends_on_h = depends_on(g.get_type(), h); + split_deps(hyps, lhs, h, non_deps, deps, depends_on_h); + // revert dependencies + expr type = Pi(deps, g.get_type()); + // substitute + bool has_dep_elim = inductive::has_dep_elim(env, get_eq_name()); + bool use_dep_elim = has_dep_elim; + if (depends_on_h) + use_dep_elim = true; + expr motive, new_type; + new_type = instantiate(abstract_local(type, mlocal_name(lhs)), rhs); + if (use_dep_elim) { + new_type = instantiate(abstract_local(new_type, mlocal_name(h)), mk_refl(*tc, rhs)); + if (symm) { + motive = Fun(lhs, Fun(h, type)); + } else { + expr Heq = mk_local(ngen.next(), local_pp_name(h), mk_eq(*tc, rhs, lhs), binder_info()); + motive = Fun(lhs, Fun(Heq, type)); + } + } else { + motive = Fun(lhs, type); + } + buffer new_hyps; + buffer intros_hyps; + new_hyps.append(non_deps); + + // reintroduce dependencies + expr new_goal_type = new_type; + for (expr const & d : deps) { + if (!is_pi(new_goal_type)) + return none_proof_state(); + expr new_h = mk_local(ngen.next(), local_pp_name(d), binding_domain(new_goal_type), binder_info()); + new_hyps.push_back(new_h); + intros_hyps.push_back(new_h); + new_goal_type = instantiate(binding_body(new_goal_type), new_h); + } + + // create new goal + expr new_metavar = mk_metavar(ngen.next(), Pi(new_hyps, new_goal_type)); + expr new_meta_core = mk_app(new_metavar, non_deps); + expr new_meta = mk_app(new_meta_core, intros_hyps); + goal new_g(new_meta, new_goal_type); + // create eqrec term + substitution new_subst = s.get_subst(); + expr major = symm ? h : mk_symm(*tc, h); + expr minor = new_meta_core; + expr A = tc->infer(lhs).first; + level l1 = sort_level(tc->ensure_type(new_type).first); + level l2 = sort_level(tc->ensure_type(A).first); + name eq_rec_name; + if (!has_dep_elim && use_dep_elim) + eq_rec_name = get_eq_drec_name(); + else + eq_rec_name = get_eq_rec_name(); + expr eqrec = mk_app({mk_constant(eq_rec_name, {l1, l2}), A, rhs, motive, minor, lhs, major}); + if (use_dep_elim) { + try { + check_term(env, g.abstract(eqrec)); + } catch (kernel_exception & ex) { + if (!s.report_failure()) + return none_proof_state(); + std::shared_ptr saved_ex(static_cast(ex.clone())); + throw tactic_exception("rewrite step failed", none_expr(), s, + [=](formatter const & fmt) { + format r; + r += format("invalid 'subst' tactic, " + "produced type incorrect term, details: "); + r += saved_ex->pp(fmt); + r += line(); + return r; + }); + } + } + expr new_val = mk_app(eqrec, deps); + assign(new_subst, g, new_val); + lean_assert(new_subst.is_assigned(g.get_mvar())); + proof_state new_s(s, goals(new_g, tail(gs)), new_subst, ngen); + return some_proof_state(new_s); +} + tactic mk_subst_tactic_core(name const & h_name, bool symm) { auto fn = [=](environment const & env, io_state const &, proof_state const & s) { - goals const & gs = s.get_goals(); - if (empty(gs)) - return none_proof_state(); - goal g = head(gs); - auto opt_h = g.find_hyp_from_internal_name(h_name); - if (!opt_h) - return none_proof_state(); - expr const & h = opt_h->first; - expr lhs, rhs; - if (!is_eq(mlocal_type(h), lhs, rhs)) - return none_proof_state(); - name_generator ngen = s.get_ngen(); - auto tc = mk_type_checker(env, ngen.mk_child()); - if (symm) - std::swap(lhs, rhs); - if (!is_local(lhs)) - return none_proof_state(); - buffer hyps, deps, non_deps; - g.get_hyps(hyps); - bool depends_on_h = depends_on(g.get_type(), h); - split_deps(hyps, lhs, h, non_deps, deps, depends_on_h); - // revert dependencies - expr type = Pi(deps, g.get_type()); - // substitute - bool has_dep_elim = inductive::has_dep_elim(env, get_eq_name()); - bool use_dep_elim = has_dep_elim; - if (depends_on_h) - use_dep_elim = true; - expr motive, new_type; - new_type = instantiate(abstract_local(type, mlocal_name(lhs)), rhs); - if (use_dep_elim) { - new_type = instantiate(abstract_local(new_type, mlocal_name(h)), mk_refl(*tc, rhs)); - if (symm) { - motive = Fun(lhs, Fun(h, type)); - } else { - expr Heq = mk_local(ngen.next(), local_pp_name(h), mk_eq(*tc, rhs, lhs), binder_info()); - motive = Fun(lhs, Fun(Heq, type)); - } - } else { - motive = Fun(lhs, type); - } - buffer new_hyps; - buffer intros_hyps; - new_hyps.append(non_deps); - - // reintroduce dependencies - expr new_goal_type = new_type; - for (expr const & d : deps) { - if (!is_pi(new_goal_type)) - return none_proof_state(); - expr new_h = mk_local(ngen.next(), local_pp_name(d), binding_domain(new_goal_type), binder_info()); - new_hyps.push_back(new_h); - intros_hyps.push_back(new_h); - new_goal_type = instantiate(binding_body(new_goal_type), new_h); - } - - // create new goal - expr new_metavar = mk_metavar(ngen.next(), Pi(new_hyps, new_goal_type)); - expr new_meta_core = mk_app(new_metavar, non_deps); - expr new_meta = mk_app(new_meta_core, intros_hyps); - goal new_g(new_meta, new_goal_type); - // create eqrec term - substitution new_subst = s.get_subst(); - expr major = symm ? h : mk_symm(*tc, h); - expr minor = new_meta_core; - expr A = tc->infer(lhs).first; - level l1 = sort_level(tc->ensure_type(new_type).first); - level l2 = sort_level(tc->ensure_type(A).first); - name eq_rec_name; - if (!has_dep_elim && use_dep_elim) - eq_rec_name = get_eq_drec_name(); - else - eq_rec_name = get_eq_rec_name(); - expr eqrec = mk_app({mk_constant(eq_rec_name, {l1, l2}), A, rhs, motive, minor, lhs, major}); - if (use_dep_elim) { - try { - check_term(env, g.abstract(eqrec)); - } catch (kernel_exception & ex) { - if (!s.report_failure()) - return none_proof_state(); - std::shared_ptr saved_ex(static_cast(ex.clone())); - throw tactic_exception("rewrite step failed", none_expr(), s, - [=](formatter const & fmt) { - format r; - r += format("invalid 'subst' tactic, " - "produced type incorrect term, details: "); - r += saved_ex->pp(fmt); - r += line(); - return r; - }); - } - } - expr new_val = mk_app(eqrec, deps); - assign(new_subst, g, new_val); - lean_assert(new_subst.is_assigned(g.get_mvar())); - proof_state new_s(s, goals(new_g, tail(gs)), new_subst, ngen); - return some_proof_state(new_s); + return subst(env, h_name, symm, s); }; return tactic01(fn); } diff --git a/src/library/tactic/subst_tactic.h b/src/library/tactic/subst_tactic.h index 91e191445..bca1400bd 100644 --- a/src/library/tactic/subst_tactic.h +++ b/src/library/tactic/subst_tactic.h @@ -5,7 +5,9 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #pragma once +#include "library/tactic/proof_state.h" namespace lean { +optional subst(environment const & env, name const & h_name, bool symm, proof_state const & s); void initialize_subst_tactic(); void finalize_subst_tactic(); }