diff --git a/src/library/tactic/location.cpp b/src/library/tactic/location.cpp index 66b7e078e..8dfd0a0ee 100644 --- a/src/library/tactic/location.cpp +++ b/src/library/tactic/location.cpp @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include "kernel/replace_fn.h" #include "library/kernel_serializer.h" #include "library/tactic/location.h" @@ -110,4 +111,16 @@ deserializer & operator>>(deserializer & d, location & loc) { loc.m_hyps = to_list(tmp); return d; } + +expr replace_occurrences(expr const & e, expr const & t, occurrence const & occ, unsigned idx) { + unsigned occ_idx = 0; + return replace(e, [&](expr const & e, unsigned offset) { + if (e == t) { + occ_idx++; + if (occ.contains(occ_idx)) + return some_expr(mk_var(offset+idx)); + } + return none_expr(); + }); +} } diff --git a/src/library/tactic/location.h b/src/library/tactic/location.h index 7ea4e8cb7..9366c607b 100644 --- a/src/library/tactic/location.h +++ b/src/library/tactic/location.h @@ -49,6 +49,11 @@ public: friend deserializer & operator>>(deserializer & d, occurrence & o); }; +/** \brief Replace occurrences of \c t in \c e with the free variable #vidx. + The j-th occurrence is replaced iff occ.contains(j) +*/ +expr replace_occurrences(expr const & e, expr const & t, occurrence const & occ, unsigned vidx); + class location { public: enum kind { Everywhere, GoalOnly, AllHypotheses, Hypotheses, GoalHypotheses }; diff --git a/src/library/tactic/rewrite_tactic.cpp b/src/library/tactic/rewrite_tactic.cpp index 5be24ec37..c59777dea 100644 --- a/src/library/tactic/rewrite_tactic.cpp +++ b/src/library/tactic/rewrite_tactic.cpp @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "kernel/instantiate.h" #include "kernel/replace_fn.h" #include "kernel/for_each_fn.h" +#include "kernel/inductive/inductive.h" #include "library/kernel_serializer.h" #include "library/reducible.h" #include "library/util.h" @@ -18,6 +19,7 @@ Author: Leonardo de Moura #include "library/local_context.h" #include "library/unifier.h" #include "library/util.h" +#include "library/constants.h" #include "library/generic_exception.h" #include "library/tactic/rewrite_tactic.h" #include "library/tactic/expr_to_tactic.h" @@ -498,6 +500,7 @@ class rewrite_fn { // rule, new_t typedef optional> unify_result; + // When successful, the result is the pair (H, new_t) where (H : new_t = t) unify_result unify_target(expr const & t, expr const & pre_elem) { try { expr rule = get_rewrite_rule(pre_elem); @@ -545,9 +548,9 @@ class rewrite_fn { expr lhs = app_arg(app_fn(rule_type)); expr rhs = app_arg(rule_type); if (symm) { - rule = mk_symm(*m_tc, rule); return unify_result(rule, lhs); } else { + rule = mk_symm(*m_tc, rule); return unify_result(rule, rhs); } } @@ -555,7 +558,7 @@ class rewrite_fn { return unify_result(); } - // target, rule, new_target : represents the rewrite rule : target = new_target + // target, new_target, H : represents the rewrite H : target = new_target typedef optional> find_result; // Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule @@ -575,7 +578,7 @@ class rewrite_fn { reset_subst(); if (r) { if (auto p = unify_target(t, pre_elem)) { - result = std::make_tuple(t, p->first, p->second); + result = std::make_tuple(t, p->second, p->first); return false; } } @@ -585,73 +588,112 @@ class rewrite_fn { return result; } - bool process_rewrite_hypothesis(expr const & hyp, expr const & elem, expr const & pre_elem, expr const & pattern) { + /** Given (a, b, P[a], Heq : b = a, occ), return (P[b], M : P[b], H : P[a]) + where M is a metavariable application of a fresh metavariable, and H is a witness (based on M) for P[a]. + + \remark occ is used to select which occurrences of a in P[a] will be replaced with b + + \remark the witness \c H is used using eq.rec + */ + std::tuple apply_rewrite(expr const & a, expr const & b, expr const & Pa, expr const & Heq, occurrence const & occ) { + bool has_dep_elim = inductive::has_dep_elim(m_env, get_eq_name()); + unsigned vidx = has_dep_elim ? 1 : 0; + expr Px = replace_occurrences(Pa, a, occ, vidx); + expr Pb = instantiate(Px, vidx, b); + expr A = m_tc->infer(a).first; + level l1 = sort_level(m_tc->ensure_type(Pa).first); + level l2 = sort_level(m_tc->ensure_type(A).first); + expr M = m_g.mk_meta(m_ngen.next(), Pb); + expr H; + if (has_dep_elim) { + expr Haeqx = mk_app(mk_constant(get_eq_name(), {l1}), A, b, mk_var(0)); + expr P = mk_lambda("x", A, mk_lambda("H", Haeqx, Px)); + H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, P, M, a, Heq}); + } else { + H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, mk_lambda("x", A, Px), M, a, Heq}); + } + return std::make_tuple(Pb, M, H); + } + + bool process_rewrite_hypothesis(expr const & hyp, expr const & pre_elem, expr const & pattern, occurrence const & occ) { // TODO(Leo) return false; } - bool process_rewrite_goal(expr const & elem, expr const & pre_elem, expr const & pattern) { - expr goal_type = m_g.get_type(); - if (auto it = find_target(goal_type, pattern, pre_elem)) { - regular(m_env, m_ios) << "FOUND\n" << std::get<0>(*it) << "\n==\n" << std::get<1>(*it) << "\n==>\n" << std::get<2>(*it) << "\n"; - // TODO(Leo) + bool process_rewrite_goal(expr const & pre_elem, expr const & pattern, occurrence const & occ) { + expr Pa = m_g.get_type(); + if (auto it = find_target(Pa, pattern, pre_elem)) { + expr a, Heq, b; + std::tie(a, b, Heq) = *it; + expr Pb, M, H; + std::tie(Pb, M, H) = apply_rewrite(a, b, Pa, Heq, occ); + goal new_g(M, Pb); + expr val = m_g.abstract(H); + m_subst.assign(m_g.get_name(), val); + update_goal(new_g); + // regular(m_env, m_ios) << "FOUND\n" << a << "\n==>\n" << b << "\nWITH\n" << Heq << "\n"; + // regular(m_env, m_ios) << H << "\n"; + return true; } return false; } - bool process_rewrite_single_step(expr const & elem, expr const & pre_elem, expr const & pattern) { + bool process_rewrite_single_step(expr const & pre_elem, expr const & pattern) { check_system("rewrite tactic"); - rewrite_info const & info = get_rewrite_info(elem); + rewrite_info const & info = get_rewrite_info(pre_elem); location const & loc = info.get_location(); if (loc.is_goal_only()) - return process_rewrite_goal(elem, pre_elem, pattern); + return process_rewrite_goal(pre_elem, pattern, *loc.includes_goal()); bool progress = false; buffer hyps; m_g.get_hyps(hyps); for (expr const & h : hyps) { - if (!loc.includes_hypothesis(local_pp_name(h))) + auto occ = loc.includes_hypothesis(local_pp_name(h)); + if (!occ) continue; - if (process_rewrite_hypothesis(h, elem, pre_elem, pattern)) + if (process_rewrite_hypothesis(h, pre_elem, pattern, *occ)) + progress = true; + } + if (auto occ = loc.includes_goal()) { + if (process_rewrite_goal(pre_elem, pattern, *occ)) progress = true; } - if (loc.includes_goal() && process_rewrite_goal(elem, pre_elem, pattern)) - progress = true; return progress; } bool process_rewrite_step(expr const & elem, expr const & pre_elem) { lean_assert(is_rewrite_step(elem)); expr pattern = get_pattern(elem); - regular(m_env, m_ios) << "pattern: " << pattern << "\n"; + // regular(m_env, m_ios) << "pattern: " << pattern << "\n"; rewrite_info const & info = get_rewrite_info(elem); unsigned num; switch (info.get_multiplicity()) { case rewrite_info::Once: - return process_rewrite_single_step(elem, pre_elem, pattern); + return process_rewrite_single_step(pre_elem, pattern); case rewrite_info::AtMostN: num = info.num(); for (unsigned i = 0; i < num; i++) { - if (!process_rewrite_single_step(elem, pre_elem, pattern)) + if (!process_rewrite_single_step(pre_elem, pattern)) return true; } return true; case rewrite_info::ExactlyN: num = info.num(); for (unsigned i = 0; i < num; i++) { - if (!process_rewrite_single_step(elem, pre_elem, pattern)) + if (!process_rewrite_single_step(pre_elem, pattern)) return false; } return true; case rewrite_info::ZeroOrMore: while (true) { - if (!process_rewrite_single_step(elem, pre_elem, pattern)) + if (!process_rewrite_single_step(pre_elem, pattern)) return true; } case rewrite_info::OneOrMore: - if (!process_rewrite_single_step(elem, pre_elem, pattern)) + if (!process_rewrite_single_step(pre_elem, pattern)) return false; while (true) { - if (!process_rewrite_single_step(elem, pre_elem, pattern)) + if (!process_rewrite_single_step(pre_elem, pattern)) return true; } } @@ -680,7 +722,6 @@ public: } proof_state_seq operator()(buffer const & elems) { - std::cout << "rewrite_tactic\n"; buffer new_elems; elaborate_elems(elems, new_elems); @@ -692,7 +733,9 @@ public: } } - return proof_state_seq(m_ps); + goals new_gs = cons(m_g, tail(m_ps.get_goals())); + proof_state new_ps(m_ps, new_gs, m_subst, m_ngen); + return proof_state_seq(new_ps); } };