feat(library/tactic/rewrite_tactic): add "fold" step
This commit is contained in:
parent
47bd5e53e2
commit
aa70334f8d
2 changed files with 125 additions and 5 deletions
|
@ -164,6 +164,9 @@ static std::string * g_rewrite_elem_opcode = nullptr;
|
|||
static name * g_rewrite_unfold_name = nullptr;
|
||||
static std::string * g_rewrite_unfold_opcode = nullptr;
|
||||
|
||||
static name * g_rewrite_fold_name = nullptr;
|
||||
static std::string * g_rewrite_fold_opcode = nullptr;
|
||||
|
||||
static name * g_rewrite_reduce_name = nullptr;
|
||||
static std::string * g_rewrite_reduce_opcode = nullptr;
|
||||
|
||||
|
@ -205,6 +208,33 @@ reduce_info const & get_rewrite_reduce_info(expr const & e) {
|
|||
return static_cast<rewrite_reduce_macro_cell const*>(macro_def(e).raw())->get_info();
|
||||
}
|
||||
|
||||
typedef reduce_info fold_info;
|
||||
|
||||
class rewrite_fold_macro_cell : public rewrite_core_macro_cell {
|
||||
fold_info m_info;
|
||||
public:
|
||||
rewrite_fold_macro_cell(fold_info const & info):m_info(info) {}
|
||||
virtual name get_name() const { return *g_rewrite_fold_name; }
|
||||
virtual void write(serializer & s) const {
|
||||
s << *g_rewrite_fold_opcode << m_info;
|
||||
}
|
||||
fold_info const & get_info() const { return m_info; }
|
||||
};
|
||||
|
||||
expr mk_rewrite_fold(expr const & e, location const & loc) {
|
||||
macro_definition def(new rewrite_fold_macro_cell(reduce_info(loc)));
|
||||
return mk_macro(def, 1, &e);
|
||||
}
|
||||
|
||||
bool is_rewrite_fold_step(expr const & e) {
|
||||
return is_macro(e) && macro_def(e).get_name() == *g_rewrite_fold_name;
|
||||
}
|
||||
|
||||
fold_info const & get_rewrite_fold_info(expr const & e) {
|
||||
lean_assert(is_rewrite_fold_step(e));
|
||||
return static_cast<rewrite_fold_macro_cell const*>(macro_def(e).raw())->get_info();
|
||||
}
|
||||
|
||||
class rewrite_unfold_macro_cell : public rewrite_core_macro_cell {
|
||||
unfold_info m_info;
|
||||
public:
|
||||
|
@ -301,7 +331,8 @@ rewrite_info const & get_rewrite_info(expr const & e) {
|
|||
|
||||
expr mk_rewrite_tactic_expr(buffer<expr> const & elems) {
|
||||
lean_assert(std::all_of(elems.begin(), elems.end(), [](expr const & e) {
|
||||
return is_rewrite_step(e) || is_rewrite_unfold_step(e) || is_rewrite_reduce_step(e);
|
||||
return is_rewrite_step(e) || is_rewrite_unfold_step(e) ||
|
||||
is_rewrite_reduce_step(e) || is_rewrite_fold_step(e);
|
||||
}));
|
||||
return mk_app(*g_rewrite_tac, mk_expr_list(elems.size(), elems.data()));
|
||||
}
|
||||
|
@ -545,6 +576,78 @@ class rewrite_fn {
|
|||
return process_reduce_step(info.get_names(), info.get_location());
|
||||
}
|
||||
|
||||
optional<expr> fold(expr const & type, expr const & e, occurrence const & occ) {
|
||||
auto ecs = m_elab(m_g, m_ngen.mk_child(), e, false);
|
||||
expr new_e = ecs.first;
|
||||
if (ecs.second)
|
||||
return none_expr(); // contain constraints...
|
||||
optional<expr> unfolded_e = unfold_app(m_env, new_e);
|
||||
if (!unfolded_e)
|
||||
return none_expr();
|
||||
bool use_cache = occ.is_all();
|
||||
unsigned occ_idx = 0;
|
||||
bool found = false;
|
||||
expr new_type =
|
||||
replace(type, [&](expr const & t, unsigned) {
|
||||
if (closed(t)) {
|
||||
constraint_seq cs;
|
||||
if (m_matcher_tc->is_def_eq(t, *unfolded_e, justification(), cs) && !cs) {
|
||||
occ_idx++;
|
||||
if (occ.contains(occ_idx)) {
|
||||
found = true;
|
||||
return some_expr(new_e);
|
||||
}
|
||||
}
|
||||
}
|
||||
return none_expr();
|
||||
}, use_cache);
|
||||
if (found)
|
||||
return some_expr(new_type);
|
||||
else
|
||||
return none_expr();
|
||||
}
|
||||
|
||||
bool process_fold_goal(expr const & e, occurrence const & occ) {
|
||||
if (auto new_type = fold(m_g.get_type(), e, occ)) {
|
||||
replace_goal(*new_type);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool process_fold_hypothesis(expr const & hyp, expr const & e, occurrence const & occ) {
|
||||
if (auto new_hyp_type = fold(mlocal_type(hyp), e, occ)) {
|
||||
replace_hypothesis(hyp, *new_hyp_type);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool process_fold_step(expr const & elem) {
|
||||
lean_assert(is_rewrite_fold_step(elem));
|
||||
location const & loc = get_rewrite_fold_info(elem).get_location();
|
||||
expr const & e = macro_arg(elem, 0);
|
||||
if (loc.is_goal_only())
|
||||
return process_fold_goal(e, *loc.includes_goal());
|
||||
bool progress = false;
|
||||
buffer<expr> hyps;
|
||||
m_g.get_hyps(hyps);
|
||||
for (expr const & h : hyps) {
|
||||
auto occ = loc.includes_hypothesis(local_pp_name(h));
|
||||
if (!occ)
|
||||
continue;
|
||||
if (process_fold_hypothesis(h, e, *occ))
|
||||
progress = true;
|
||||
}
|
||||
if (auto occ = loc.includes_goal()) {
|
||||
if (process_fold_goal(e, *occ))
|
||||
progress = true;
|
||||
}
|
||||
return progress;
|
||||
}
|
||||
|
||||
optional<expr> unify_with(expr const & t, expr const & e) {
|
||||
auto ecs = m_elab(m_g, m_ngen.mk_child(), e, false);
|
||||
expr new_e = ecs.first;
|
||||
|
@ -783,7 +886,8 @@ class rewrite_fn {
|
|||
return unify_result();
|
||||
}
|
||||
|
||||
// target, new_target, H : represents the rewrite (H : target = new_target) for hypothesis and (H : new_target = target) for goals
|
||||
// target, new_target, H : represents the rewrite (H : target = new_target) for hypothesis
|
||||
// and (H : new_target = target) for goals
|
||||
typedef optional<std::tuple<expr, expr, expr>> find_result;
|
||||
|
||||
// Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule
|
||||
|
@ -977,6 +1081,8 @@ class rewrite_fn {
|
|||
bool process_step(expr const & elem) {
|
||||
if (is_rewrite_unfold_step(elem)) {
|
||||
return process_unfold_step(elem);
|
||||
} else if (is_rewrite_fold_step(elem)) {
|
||||
return process_fold_step(elem);
|
||||
} else if (is_rewrite_reduce_step(elem)) {
|
||||
return process_reduce_step(elem);
|
||||
} else {
|
||||
|
@ -1067,15 +1173,19 @@ tactic mk_rewrite_tactic(elaborate_fn const & elab, buffer<expr> const & elems)
|
|||
|
||||
void initialize_rewrite_tactic() {
|
||||
g_rewriter_max_iterations = new name{"rewriter", "max_iter"};
|
||||
register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations");
|
||||
register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS,
|
||||
"(rewriter tactic) maximum number of iterations");
|
||||
g_rewriter_syntactic = new name{"rewriter", "syntactic"};
|
||||
register_bool_option(*g_rewriter_syntactic, LEAN_DEFAULT_REWRITER_SYNTACTIC, "(rewriter tactic) if true tactic will not unfold any constant when performing pattern matching");
|
||||
register_bool_option(*g_rewriter_syntactic, LEAN_DEFAULT_REWRITER_SYNTACTIC,
|
||||
"(rewriter tactic) if true tactic will not unfold any constant when performing pattern matching");
|
||||
name rewrite_tac_name{"tactic", "rewrite_tac"};
|
||||
g_rewrite_tac = new expr(Const(rewrite_tac_name));
|
||||
g_rewrite_reduce_name = new name("rewrite_reduce");
|
||||
g_rewrite_reduce_opcode = new std::string("RWR");
|
||||
g_rewrite_unfold_name = new name("rewrite_unfold");
|
||||
g_rewrite_unfold_opcode = new std::string("RWU");
|
||||
g_rewrite_fold_name = new name("rewrite_fold");
|
||||
g_rewrite_fold_opcode = new std::string("RWF");
|
||||
g_rewrite_elem_name = new name("rewrite_element");
|
||||
g_rewrite_elem_opcode = new std::string("RWE");
|
||||
register_macro_deserializer(*g_rewrite_reduce_opcode,
|
||||
|
@ -1089,6 +1199,14 @@ void initialize_rewrite_tactic() {
|
|||
else
|
||||
return mk_rewrite_reduce_to(args[0], info.get_location());
|
||||
});
|
||||
register_macro_deserializer(*g_rewrite_fold_opcode,
|
||||
[](deserializer & d, unsigned num, expr const * args) {
|
||||
if (num != 1)
|
||||
throw corrupted_stream_exception();
|
||||
fold_info info;
|
||||
d >> info;
|
||||
return mk_rewrite_fold(args[0], info.get_location());
|
||||
});
|
||||
register_macro_deserializer(*g_rewrite_unfold_opcode,
|
||||
[](deserializer & d, unsigned num, expr const *) {
|
||||
if (num != 0)
|
||||
|
@ -1112,7 +1230,8 @@ void initialize_rewrite_tactic() {
|
|||
buffer<expr> args;
|
||||
get_tactic_expr_list_elements(app_arg(e), args, "invalid 'rewrite' tactic, invalid argument");
|
||||
for (expr const & arg : args) {
|
||||
if (!is_rewrite_step(arg) && !is_rewrite_unfold_step(arg) && !is_rewrite_reduce_step(arg))
|
||||
if (!is_rewrite_step(arg) && !is_rewrite_unfold_step(arg) &&
|
||||
!is_rewrite_reduce_step(arg) && !is_rewrite_fold_step(arg))
|
||||
throw expr_to_tactic_exception(e, "invalid 'rewrite' tactic, invalid argument");
|
||||
}
|
||||
return mk_rewrite_tactic(elab, args);
|
||||
|
|
|
@ -12,6 +12,7 @@ namespace lean {
|
|||
expr mk_rewrite_unfold(list<name> const & ns, location const & loc);
|
||||
expr mk_rewrite_reduce(location const & loc);
|
||||
expr mk_rewrite_reduce_to(expr const & e, location const & loc);
|
||||
expr mk_rewrite_fold(expr const & e, location const & loc);
|
||||
expr mk_rewrite_once(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
||||
expr mk_rewrite_zero_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
||||
expr mk_rewrite_one_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
||||
|
|
Loading…
Reference in a new issue