feat(library/tactic/rewrite_tactic): add "fold" step

This commit is contained in:
Leonardo de Moura 2015-02-06 15:21:49 -08:00
parent 47bd5e53e2
commit aa70334f8d
2 changed files with 125 additions and 5 deletions

View file

@ -164,6 +164,9 @@ static std::string * g_rewrite_elem_opcode = nullptr;
static name * g_rewrite_unfold_name = nullptr;
static std::string * g_rewrite_unfold_opcode = nullptr;
static name * g_rewrite_fold_name = nullptr;
static std::string * g_rewrite_fold_opcode = nullptr;
static name * g_rewrite_reduce_name = nullptr;
static std::string * g_rewrite_reduce_opcode = nullptr;
@ -205,6 +208,33 @@ reduce_info const & get_rewrite_reduce_info(expr const & e) {
return static_cast<rewrite_reduce_macro_cell const*>(macro_def(e).raw())->get_info();
}
typedef reduce_info fold_info;
class rewrite_fold_macro_cell : public rewrite_core_macro_cell {
fold_info m_info;
public:
rewrite_fold_macro_cell(fold_info const & info):m_info(info) {}
virtual name get_name() const { return *g_rewrite_fold_name; }
virtual void write(serializer & s) const {
s << *g_rewrite_fold_opcode << m_info;
}
fold_info const & get_info() const { return m_info; }
};
expr mk_rewrite_fold(expr const & e, location const & loc) {
macro_definition def(new rewrite_fold_macro_cell(reduce_info(loc)));
return mk_macro(def, 1, &e);
}
bool is_rewrite_fold_step(expr const & e) {
return is_macro(e) && macro_def(e).get_name() == *g_rewrite_fold_name;
}
fold_info const & get_rewrite_fold_info(expr const & e) {
lean_assert(is_rewrite_fold_step(e));
return static_cast<rewrite_fold_macro_cell const*>(macro_def(e).raw())->get_info();
}
class rewrite_unfold_macro_cell : public rewrite_core_macro_cell {
unfold_info m_info;
public:
@ -301,7 +331,8 @@ rewrite_info const & get_rewrite_info(expr const & e) {
expr mk_rewrite_tactic_expr(buffer<expr> const & elems) {
lean_assert(std::all_of(elems.begin(), elems.end(), [](expr const & e) {
return is_rewrite_step(e) || is_rewrite_unfold_step(e) || is_rewrite_reduce_step(e);
return is_rewrite_step(e) || is_rewrite_unfold_step(e) ||
is_rewrite_reduce_step(e) || is_rewrite_fold_step(e);
}));
return mk_app(*g_rewrite_tac, mk_expr_list(elems.size(), elems.data()));
}
@ -545,6 +576,78 @@ class rewrite_fn {
return process_reduce_step(info.get_names(), info.get_location());
}
optional<expr> fold(expr const & type, expr const & e, occurrence const & occ) {
auto ecs = m_elab(m_g, m_ngen.mk_child(), e, false);
expr new_e = ecs.first;
if (ecs.second)
return none_expr(); // contain constraints...
optional<expr> unfolded_e = unfold_app(m_env, new_e);
if (!unfolded_e)
return none_expr();
bool use_cache = occ.is_all();
unsigned occ_idx = 0;
bool found = false;
expr new_type =
replace(type, [&](expr const & t, unsigned) {
if (closed(t)) {
constraint_seq cs;
if (m_matcher_tc->is_def_eq(t, *unfolded_e, justification(), cs) && !cs) {
occ_idx++;
if (occ.contains(occ_idx)) {
found = true;
return some_expr(new_e);
}
}
}
return none_expr();
}, use_cache);
if (found)
return some_expr(new_type);
else
return none_expr();
}
bool process_fold_goal(expr const & e, occurrence const & occ) {
if (auto new_type = fold(m_g.get_type(), e, occ)) {
replace_goal(*new_type);
return true;
} else {
return false;
}
}
bool process_fold_hypothesis(expr const & hyp, expr const & e, occurrence const & occ) {
if (auto new_hyp_type = fold(mlocal_type(hyp), e, occ)) {
replace_hypothesis(hyp, *new_hyp_type);
return true;
} else {
return false;
}
}
bool process_fold_step(expr const & elem) {
lean_assert(is_rewrite_fold_step(elem));
location const & loc = get_rewrite_fold_info(elem).get_location();
expr const & e = macro_arg(elem, 0);
if (loc.is_goal_only())
return process_fold_goal(e, *loc.includes_goal());
bool progress = false;
buffer<expr> hyps;
m_g.get_hyps(hyps);
for (expr const & h : hyps) {
auto occ = loc.includes_hypothesis(local_pp_name(h));
if (!occ)
continue;
if (process_fold_hypothesis(h, e, *occ))
progress = true;
}
if (auto occ = loc.includes_goal()) {
if (process_fold_goal(e, *occ))
progress = true;
}
return progress;
}
optional<expr> unify_with(expr const & t, expr const & e) {
auto ecs = m_elab(m_g, m_ngen.mk_child(), e, false);
expr new_e = ecs.first;
@ -783,7 +886,8 @@ class rewrite_fn {
return unify_result();
}
// target, new_target, H : represents the rewrite (H : target = new_target) for hypothesis and (H : new_target = target) for goals
// target, new_target, H : represents the rewrite (H : target = new_target) for hypothesis
// and (H : new_target = target) for goals
typedef optional<std::tuple<expr, expr, expr>> find_result;
// Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule
@ -977,6 +1081,8 @@ class rewrite_fn {
bool process_step(expr const & elem) {
if (is_rewrite_unfold_step(elem)) {
return process_unfold_step(elem);
} else if (is_rewrite_fold_step(elem)) {
return process_fold_step(elem);
} else if (is_rewrite_reduce_step(elem)) {
return process_reduce_step(elem);
} else {
@ -1067,15 +1173,19 @@ tactic mk_rewrite_tactic(elaborate_fn const & elab, buffer<expr> const & elems)
void initialize_rewrite_tactic() {
g_rewriter_max_iterations = new name{"rewriter", "max_iter"};
register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations");
register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS,
"(rewriter tactic) maximum number of iterations");
g_rewriter_syntactic = new name{"rewriter", "syntactic"};
register_bool_option(*g_rewriter_syntactic, LEAN_DEFAULT_REWRITER_SYNTACTIC, "(rewriter tactic) if true tactic will not unfold any constant when performing pattern matching");
register_bool_option(*g_rewriter_syntactic, LEAN_DEFAULT_REWRITER_SYNTACTIC,
"(rewriter tactic) if true tactic will not unfold any constant when performing pattern matching");
name rewrite_tac_name{"tactic", "rewrite_tac"};
g_rewrite_tac = new expr(Const(rewrite_tac_name));
g_rewrite_reduce_name = new name("rewrite_reduce");
g_rewrite_reduce_opcode = new std::string("RWR");
g_rewrite_unfold_name = new name("rewrite_unfold");
g_rewrite_unfold_opcode = new std::string("RWU");
g_rewrite_fold_name = new name("rewrite_fold");
g_rewrite_fold_opcode = new std::string("RWF");
g_rewrite_elem_name = new name("rewrite_element");
g_rewrite_elem_opcode = new std::string("RWE");
register_macro_deserializer(*g_rewrite_reduce_opcode,
@ -1089,6 +1199,14 @@ void initialize_rewrite_tactic() {
else
return mk_rewrite_reduce_to(args[0], info.get_location());
});
register_macro_deserializer(*g_rewrite_fold_opcode,
[](deserializer & d, unsigned num, expr const * args) {
if (num != 1)
throw corrupted_stream_exception();
fold_info info;
d >> info;
return mk_rewrite_fold(args[0], info.get_location());
});
register_macro_deserializer(*g_rewrite_unfold_opcode,
[](deserializer & d, unsigned num, expr const *) {
if (num != 0)
@ -1112,7 +1230,8 @@ void initialize_rewrite_tactic() {
buffer<expr> args;
get_tactic_expr_list_elements(app_arg(e), args, "invalid 'rewrite' tactic, invalid argument");
for (expr const & arg : args) {
if (!is_rewrite_step(arg) && !is_rewrite_unfold_step(arg) && !is_rewrite_reduce_step(arg))
if (!is_rewrite_step(arg) && !is_rewrite_unfold_step(arg) &&
!is_rewrite_reduce_step(arg) && !is_rewrite_fold_step(arg))
throw expr_to_tactic_exception(e, "invalid 'rewrite' tactic, invalid argument");
}
return mk_rewrite_tactic(elab, args);

View file

@ -12,6 +12,7 @@ namespace lean {
expr mk_rewrite_unfold(list<name> const & ns, location const & loc);
expr mk_rewrite_reduce(location const & loc);
expr mk_rewrite_reduce_to(expr const & e, location const & loc);
expr mk_rewrite_fold(expr const & e, location const & loc);
expr mk_rewrite_once(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
expr mk_rewrite_zero_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
expr mk_rewrite_one_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);