feat(library/tactic/rewrite_tactic): rewrite goal

This commit is contained in:
Leonardo de Moura 2015-02-04 15:17:58 -08:00
parent 599de0271b
commit e5381679d6
3 changed files with 86 additions and 25 deletions

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "kernel/replace_fn.h"
#include "library/kernel_serializer.h"
#include "library/tactic/location.h"
@ -110,4 +111,16 @@ deserializer & operator>>(deserializer & d, location & loc) {
loc.m_hyps = to_list(tmp);
return d;
}
expr replace_occurrences(expr const & e, expr const & t, occurrence const & occ, unsigned idx) {
unsigned occ_idx = 0;
return replace(e, [&](expr const & e, unsigned offset) {
if (e == t) {
occ_idx++;
if (occ.contains(occ_idx))
return some_expr(mk_var(offset+idx));
}
return none_expr();
});
}
}

View file

@ -49,6 +49,11 @@ public:
friend deserializer & operator>>(deserializer & d, occurrence & o);
};
/** \brief Replace occurrences of \c t in \c e with the free variable #vidx.
The j-th occurrence is replaced iff occ.contains(j)
*/
expr replace_occurrences(expr const & e, expr const & t, occurrence const & occ, unsigned vidx);
class location {
public:
enum kind { Everywhere, GoalOnly, AllHypotheses, Hypotheses, GoalHypotheses };

View file

@ -10,6 +10,7 @@ Author: Leonardo de Moura
#include "kernel/instantiate.h"
#include "kernel/replace_fn.h"
#include "kernel/for_each_fn.h"
#include "kernel/inductive/inductive.h"
#include "library/kernel_serializer.h"
#include "library/reducible.h"
#include "library/util.h"
@ -18,6 +19,7 @@ Author: Leonardo de Moura
#include "library/local_context.h"
#include "library/unifier.h"
#include "library/util.h"
#include "library/constants.h"
#include "library/generic_exception.h"
#include "library/tactic/rewrite_tactic.h"
#include "library/tactic/expr_to_tactic.h"
@ -498,6 +500,7 @@ class rewrite_fn {
// rule, new_t
typedef optional<pair<expr, expr>> unify_result;
// When successful, the result is the pair (H, new_t) where (H : new_t = t)
unify_result unify_target(expr const & t, expr const & pre_elem) {
try {
expr rule = get_rewrite_rule(pre_elem);
@ -545,9 +548,9 @@ class rewrite_fn {
expr lhs = app_arg(app_fn(rule_type));
expr rhs = app_arg(rule_type);
if (symm) {
rule = mk_symm(*m_tc, rule);
return unify_result(rule, lhs);
} else {
rule = mk_symm(*m_tc, rule);
return unify_result(rule, rhs);
}
}
@ -555,7 +558,7 @@ class rewrite_fn {
return unify_result();
}
// target, rule, new_target : represents the rewrite rule : target = new_target
// target, new_target, H : represents the rewrite H : target = new_target
typedef optional<std::tuple<expr, expr, expr>> find_result;
// Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule
@ -575,7 +578,7 @@ class rewrite_fn {
reset_subst();
if (r) {
if (auto p = unify_target(t, pre_elem)) {
result = std::make_tuple(t, p->first, p->second);
result = std::make_tuple(t, p->second, p->first);
return false;
}
}
@ -585,73 +588,112 @@ class rewrite_fn {
return result;
}
bool process_rewrite_hypothesis(expr const & hyp, expr const & elem, expr const & pre_elem, expr const & pattern) {
/** Given (a, b, P[a], Heq : b = a, occ), return (P[b], M : P[b], H : P[a])
where M is a metavariable application of a fresh metavariable, and H is a witness (based on M) for P[a].
\remark occ is used to select which occurrences of a in P[a] will be replaced with b
\remark the witness \c H is used using eq.rec
*/
std::tuple<expr, expr, expr> apply_rewrite(expr const & a, expr const & b, expr const & Pa, expr const & Heq, occurrence const & occ) {
bool has_dep_elim = inductive::has_dep_elim(m_env, get_eq_name());
unsigned vidx = has_dep_elim ? 1 : 0;
expr Px = replace_occurrences(Pa, a, occ, vidx);
expr Pb = instantiate(Px, vidx, b);
expr A = m_tc->infer(a).first;
level l1 = sort_level(m_tc->ensure_type(Pa).first);
level l2 = sort_level(m_tc->ensure_type(A).first);
expr M = m_g.mk_meta(m_ngen.next(), Pb);
expr H;
if (has_dep_elim) {
expr Haeqx = mk_app(mk_constant(get_eq_name(), {l1}), A, b, mk_var(0));
expr P = mk_lambda("x", A, mk_lambda("H", Haeqx, Px));
H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, P, M, a, Heq});
} else {
H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, mk_lambda("x", A, Px), M, a, Heq});
}
return std::make_tuple(Pb, M, H);
}
bool process_rewrite_hypothesis(expr const & hyp, expr const & pre_elem, expr const & pattern, occurrence const & occ) {
// TODO(Leo)
return false;
}
bool process_rewrite_goal(expr const & elem, expr const & pre_elem, expr const & pattern) {
expr goal_type = m_g.get_type();
if (auto it = find_target(goal_type, pattern, pre_elem)) {
regular(m_env, m_ios) << "FOUND\n" << std::get<0>(*it) << "\n==\n" << std::get<1>(*it) << "\n==>\n" << std::get<2>(*it) << "\n";
// TODO(Leo)
bool process_rewrite_goal(expr const & pre_elem, expr const & pattern, occurrence const & occ) {
expr Pa = m_g.get_type();
if (auto it = find_target(Pa, pattern, pre_elem)) {
expr a, Heq, b;
std::tie(a, b, Heq) = *it;
expr Pb, M, H;
std::tie(Pb, M, H) = apply_rewrite(a, b, Pa, Heq, occ);
goal new_g(M, Pb);
expr val = m_g.abstract(H);
m_subst.assign(m_g.get_name(), val);
update_goal(new_g);
// regular(m_env, m_ios) << "FOUND\n" << a << "\n==>\n" << b << "\nWITH\n" << Heq << "\n";
// regular(m_env, m_ios) << H << "\n";
return true;
}
return false;
}
bool process_rewrite_single_step(expr const & elem, expr const & pre_elem, expr const & pattern) {
bool process_rewrite_single_step(expr const & pre_elem, expr const & pattern) {
check_system("rewrite tactic");
rewrite_info const & info = get_rewrite_info(elem);
rewrite_info const & info = get_rewrite_info(pre_elem);
location const & loc = info.get_location();
if (loc.is_goal_only())
return process_rewrite_goal(elem, pre_elem, pattern);
return process_rewrite_goal(pre_elem, pattern, *loc.includes_goal());
bool progress = false;
buffer<expr> hyps;
m_g.get_hyps(hyps);
for (expr const & h : hyps) {
if (!loc.includes_hypothesis(local_pp_name(h)))
auto occ = loc.includes_hypothesis(local_pp_name(h));
if (!occ)
continue;
if (process_rewrite_hypothesis(h, elem, pre_elem, pattern))
if (process_rewrite_hypothesis(h, pre_elem, pattern, *occ))
progress = true;
}
if (auto occ = loc.includes_goal()) {
if (process_rewrite_goal(pre_elem, pattern, *occ))
progress = true;
}
if (loc.includes_goal() && process_rewrite_goal(elem, pre_elem, pattern))
progress = true;
return progress;
}
bool process_rewrite_step(expr const & elem, expr const & pre_elem) {
lean_assert(is_rewrite_step(elem));
expr pattern = get_pattern(elem);
regular(m_env, m_ios) << "pattern: " << pattern << "\n";
// regular(m_env, m_ios) << "pattern: " << pattern << "\n";
rewrite_info const & info = get_rewrite_info(elem);
unsigned num;
switch (info.get_multiplicity()) {
case rewrite_info::Once:
return process_rewrite_single_step(elem, pre_elem, pattern);
return process_rewrite_single_step(pre_elem, pattern);
case rewrite_info::AtMostN:
num = info.num();
for (unsigned i = 0; i < num; i++) {
if (!process_rewrite_single_step(elem, pre_elem, pattern))
if (!process_rewrite_single_step(pre_elem, pattern))
return true;
}
return true;
case rewrite_info::ExactlyN:
num = info.num();
for (unsigned i = 0; i < num; i++) {
if (!process_rewrite_single_step(elem, pre_elem, pattern))
if (!process_rewrite_single_step(pre_elem, pattern))
return false;
}
return true;
case rewrite_info::ZeroOrMore:
while (true) {
if (!process_rewrite_single_step(elem, pre_elem, pattern))
if (!process_rewrite_single_step(pre_elem, pattern))
return true;
}
case rewrite_info::OneOrMore:
if (!process_rewrite_single_step(elem, pre_elem, pattern))
if (!process_rewrite_single_step(pre_elem, pattern))
return false;
while (true) {
if (!process_rewrite_single_step(elem, pre_elem, pattern))
if (!process_rewrite_single_step(pre_elem, pattern))
return true;
}
}
@ -680,7 +722,6 @@ public:
}
proof_state_seq operator()(buffer<expr> const & elems) {
std::cout << "rewrite_tactic\n";
buffer<expr> new_elems;
elaborate_elems(elems, new_elems);
@ -692,7 +733,9 @@ public:
}
}
return proof_state_seq(m_ps);
goals new_gs = cons(m_g, tail(m_ps.get_goals()));
proof_state new_ps(m_ps, new_gs, m_subst, m_ngen);
return proof_state_seq(new_ps);
}
};