feat(library/tactic/rewrite_tactic): rewrite goal

This commit is contained in:
Leonardo de Moura 2015-02-04 15:17:58 -08:00
parent 599de0271b
commit e5381679d6
3 changed files with 86 additions and 25 deletions

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura Author: Leonardo de Moura
*/ */
#include "kernel/replace_fn.h"
#include "library/kernel_serializer.h" #include "library/kernel_serializer.h"
#include "library/tactic/location.h" #include "library/tactic/location.h"
@ -110,4 +111,16 @@ deserializer & operator>>(deserializer & d, location & loc) {
loc.m_hyps = to_list(tmp); loc.m_hyps = to_list(tmp);
return d; return d;
} }
expr replace_occurrences(expr const & e, expr const & t, occurrence const & occ, unsigned idx) {
unsigned occ_idx = 0;
return replace(e, [&](expr const & e, unsigned offset) {
if (e == t) {
occ_idx++;
if (occ.contains(occ_idx))
return some_expr(mk_var(offset+idx));
}
return none_expr();
});
}
} }

View file

@ -49,6 +49,11 @@ public:
friend deserializer & operator>>(deserializer & d, occurrence & o); friend deserializer & operator>>(deserializer & d, occurrence & o);
}; };
/** \brief Replace occurrences of \c t in \c e with the free variable #vidx.
The j-th occurrence is replaced iff occ.contains(j)
*/
expr replace_occurrences(expr const & e, expr const & t, occurrence const & occ, unsigned vidx);
class location { class location {
public: public:
enum kind { Everywhere, GoalOnly, AllHypotheses, Hypotheses, GoalHypotheses }; enum kind { Everywhere, GoalOnly, AllHypotheses, Hypotheses, GoalHypotheses };

View file

@ -10,6 +10,7 @@ Author: Leonardo de Moura
#include "kernel/instantiate.h" #include "kernel/instantiate.h"
#include "kernel/replace_fn.h" #include "kernel/replace_fn.h"
#include "kernel/for_each_fn.h" #include "kernel/for_each_fn.h"
#include "kernel/inductive/inductive.h"
#include "library/kernel_serializer.h" #include "library/kernel_serializer.h"
#include "library/reducible.h" #include "library/reducible.h"
#include "library/util.h" #include "library/util.h"
@ -18,6 +19,7 @@ Author: Leonardo de Moura
#include "library/local_context.h" #include "library/local_context.h"
#include "library/unifier.h" #include "library/unifier.h"
#include "library/util.h" #include "library/util.h"
#include "library/constants.h"
#include "library/generic_exception.h" #include "library/generic_exception.h"
#include "library/tactic/rewrite_tactic.h" #include "library/tactic/rewrite_tactic.h"
#include "library/tactic/expr_to_tactic.h" #include "library/tactic/expr_to_tactic.h"
@ -498,6 +500,7 @@ class rewrite_fn {
// rule, new_t // rule, new_t
typedef optional<pair<expr, expr>> unify_result; typedef optional<pair<expr, expr>> unify_result;
// When successful, the result is the pair (H, new_t) where (H : new_t = t)
unify_result unify_target(expr const & t, expr const & pre_elem) { unify_result unify_target(expr const & t, expr const & pre_elem) {
try { try {
expr rule = get_rewrite_rule(pre_elem); expr rule = get_rewrite_rule(pre_elem);
@ -545,9 +548,9 @@ class rewrite_fn {
expr lhs = app_arg(app_fn(rule_type)); expr lhs = app_arg(app_fn(rule_type));
expr rhs = app_arg(rule_type); expr rhs = app_arg(rule_type);
if (symm) { if (symm) {
rule = mk_symm(*m_tc, rule);
return unify_result(rule, lhs); return unify_result(rule, lhs);
} else { } else {
rule = mk_symm(*m_tc, rule);
return unify_result(rule, rhs); return unify_result(rule, rhs);
} }
} }
@ -555,7 +558,7 @@ class rewrite_fn {
return unify_result(); return unify_result();
} }
// target, rule, new_target : represents the rewrite rule : target = new_target // target, new_target, H : represents the rewrite H : target = new_target
typedef optional<std::tuple<expr, expr, expr>> find_result; typedef optional<std::tuple<expr, expr, expr>> find_result;
// Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule // Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule
@ -575,7 +578,7 @@ class rewrite_fn {
reset_subst(); reset_subst();
if (r) { if (r) {
if (auto p = unify_target(t, pre_elem)) { if (auto p = unify_target(t, pre_elem)) {
result = std::make_tuple(t, p->first, p->second); result = std::make_tuple(t, p->second, p->first);
return false; return false;
} }
} }
@ -585,73 +588,112 @@ class rewrite_fn {
return result; return result;
} }
bool process_rewrite_hypothesis(expr const & hyp, expr const & elem, expr const & pre_elem, expr const & pattern) { /** Given (a, b, P[a], Heq : b = a, occ), return (P[b], M : P[b], H : P[a])
where M is a metavariable application of a fresh metavariable, and H is a witness (based on M) for P[a].
\remark occ is used to select which occurrences of a in P[a] will be replaced with b
\remark the witness \c H is used using eq.rec
*/
std::tuple<expr, expr, expr> apply_rewrite(expr const & a, expr const & b, expr const & Pa, expr const & Heq, occurrence const & occ) {
bool has_dep_elim = inductive::has_dep_elim(m_env, get_eq_name());
unsigned vidx = has_dep_elim ? 1 : 0;
expr Px = replace_occurrences(Pa, a, occ, vidx);
expr Pb = instantiate(Px, vidx, b);
expr A = m_tc->infer(a).first;
level l1 = sort_level(m_tc->ensure_type(Pa).first);
level l2 = sort_level(m_tc->ensure_type(A).first);
expr M = m_g.mk_meta(m_ngen.next(), Pb);
expr H;
if (has_dep_elim) {
expr Haeqx = mk_app(mk_constant(get_eq_name(), {l1}), A, b, mk_var(0));
expr P = mk_lambda("x", A, mk_lambda("H", Haeqx, Px));
H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, P, M, a, Heq});
} else {
H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, mk_lambda("x", A, Px), M, a, Heq});
}
return std::make_tuple(Pb, M, H);
}
bool process_rewrite_hypothesis(expr const & hyp, expr const & pre_elem, expr const & pattern, occurrence const & occ) {
// TODO(Leo) // TODO(Leo)
return false; return false;
} }
bool process_rewrite_goal(expr const & elem, expr const & pre_elem, expr const & pattern) { bool process_rewrite_goal(expr const & pre_elem, expr const & pattern, occurrence const & occ) {
expr goal_type = m_g.get_type(); expr Pa = m_g.get_type();
if (auto it = find_target(goal_type, pattern, pre_elem)) { if (auto it = find_target(Pa, pattern, pre_elem)) {
regular(m_env, m_ios) << "FOUND\n" << std::get<0>(*it) << "\n==\n" << std::get<1>(*it) << "\n==>\n" << std::get<2>(*it) << "\n"; expr a, Heq, b;
// TODO(Leo) std::tie(a, b, Heq) = *it;
expr Pb, M, H;
std::tie(Pb, M, H) = apply_rewrite(a, b, Pa, Heq, occ);
goal new_g(M, Pb);
expr val = m_g.abstract(H);
m_subst.assign(m_g.get_name(), val);
update_goal(new_g);
// regular(m_env, m_ios) << "FOUND\n" << a << "\n==>\n" << b << "\nWITH\n" << Heq << "\n";
// regular(m_env, m_ios) << H << "\n";
return true;
} }
return false; return false;
} }
bool process_rewrite_single_step(expr const & elem, expr const & pre_elem, expr const & pattern) { bool process_rewrite_single_step(expr const & pre_elem, expr const & pattern) {
check_system("rewrite tactic"); check_system("rewrite tactic");
rewrite_info const & info = get_rewrite_info(elem); rewrite_info const & info = get_rewrite_info(pre_elem);
location const & loc = info.get_location(); location const & loc = info.get_location();
if (loc.is_goal_only()) if (loc.is_goal_only())
return process_rewrite_goal(elem, pre_elem, pattern); return process_rewrite_goal(pre_elem, pattern, *loc.includes_goal());
bool progress = false; bool progress = false;
buffer<expr> hyps; buffer<expr> hyps;
m_g.get_hyps(hyps); m_g.get_hyps(hyps);
for (expr const & h : hyps) { for (expr const & h : hyps) {
if (!loc.includes_hypothesis(local_pp_name(h))) auto occ = loc.includes_hypothesis(local_pp_name(h));
if (!occ)
continue; continue;
if (process_rewrite_hypothesis(h, elem, pre_elem, pattern)) if (process_rewrite_hypothesis(h, pre_elem, pattern, *occ))
progress = true; progress = true;
} }
if (loc.includes_goal() && process_rewrite_goal(elem, pre_elem, pattern)) if (auto occ = loc.includes_goal()) {
if (process_rewrite_goal(pre_elem, pattern, *occ))
progress = true; progress = true;
}
return progress; return progress;
} }
bool process_rewrite_step(expr const & elem, expr const & pre_elem) { bool process_rewrite_step(expr const & elem, expr const & pre_elem) {
lean_assert(is_rewrite_step(elem)); lean_assert(is_rewrite_step(elem));
expr pattern = get_pattern(elem); expr pattern = get_pattern(elem);
regular(m_env, m_ios) << "pattern: " << pattern << "\n"; // regular(m_env, m_ios) << "pattern: " << pattern << "\n";
rewrite_info const & info = get_rewrite_info(elem); rewrite_info const & info = get_rewrite_info(elem);
unsigned num; unsigned num;
switch (info.get_multiplicity()) { switch (info.get_multiplicity()) {
case rewrite_info::Once: case rewrite_info::Once:
return process_rewrite_single_step(elem, pre_elem, pattern); return process_rewrite_single_step(pre_elem, pattern);
case rewrite_info::AtMostN: case rewrite_info::AtMostN:
num = info.num(); num = info.num();
for (unsigned i = 0; i < num; i++) { for (unsigned i = 0; i < num; i++) {
if (!process_rewrite_single_step(elem, pre_elem, pattern)) if (!process_rewrite_single_step(pre_elem, pattern))
return true; return true;
} }
return true; return true;
case rewrite_info::ExactlyN: case rewrite_info::ExactlyN:
num = info.num(); num = info.num();
for (unsigned i = 0; i < num; i++) { for (unsigned i = 0; i < num; i++) {
if (!process_rewrite_single_step(elem, pre_elem, pattern)) if (!process_rewrite_single_step(pre_elem, pattern))
return false; return false;
} }
return true; return true;
case rewrite_info::ZeroOrMore: case rewrite_info::ZeroOrMore:
while (true) { while (true) {
if (!process_rewrite_single_step(elem, pre_elem, pattern)) if (!process_rewrite_single_step(pre_elem, pattern))
return true; return true;
} }
case rewrite_info::OneOrMore: case rewrite_info::OneOrMore:
if (!process_rewrite_single_step(elem, pre_elem, pattern)) if (!process_rewrite_single_step(pre_elem, pattern))
return false; return false;
while (true) { while (true) {
if (!process_rewrite_single_step(elem, pre_elem, pattern)) if (!process_rewrite_single_step(pre_elem, pattern))
return true; return true;
} }
} }
@ -680,7 +722,6 @@ public:
} }
proof_state_seq operator()(buffer<expr> const & elems) { proof_state_seq operator()(buffer<expr> const & elems) {
std::cout << "rewrite_tactic\n";
buffer<expr> new_elems; buffer<expr> new_elems;
elaborate_elems(elems, new_elems); elaborate_elems(elems, new_elems);
@ -692,7 +733,9 @@ public:
} }
} }
return proof_state_seq(m_ps); goals new_gs = cons(m_g, tail(m_ps.get_goals()));
proof_state new_ps(m_ps, new_gs, m_subst, m_ngen);
return proof_state_seq(new_ps);
} }
}; };