From aa70334f8d71a58020e1f57b9272677d97e69da2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 6 Feb 2015 15:21:49 -0800 Subject: [PATCH] feat(library/tactic/rewrite_tactic): add "fold" step --- src/library/tactic/rewrite_tactic.cpp | 129 +++++++++++++++++++++++++- src/library/tactic/rewrite_tactic.h | 1 + 2 files changed, 125 insertions(+), 5 deletions(-) diff --git a/src/library/tactic/rewrite_tactic.cpp b/src/library/tactic/rewrite_tactic.cpp index d4397f1bb..a6cbe11f7 100644 --- a/src/library/tactic/rewrite_tactic.cpp +++ b/src/library/tactic/rewrite_tactic.cpp @@ -164,6 +164,9 @@ static std::string * g_rewrite_elem_opcode = nullptr; static name * g_rewrite_unfold_name = nullptr; static std::string * g_rewrite_unfold_opcode = nullptr; +static name * g_rewrite_fold_name = nullptr; +static std::string * g_rewrite_fold_opcode = nullptr; + static name * g_rewrite_reduce_name = nullptr; static std::string * g_rewrite_reduce_opcode = nullptr; @@ -205,6 +208,33 @@ reduce_info const & get_rewrite_reduce_info(expr const & e) { return static_cast(macro_def(e).raw())->get_info(); } +typedef reduce_info fold_info; + +class rewrite_fold_macro_cell : public rewrite_core_macro_cell { + fold_info m_info; +public: + rewrite_fold_macro_cell(fold_info const & info):m_info(info) {} + virtual name get_name() const { return *g_rewrite_fold_name; } + virtual void write(serializer & s) const { + s << *g_rewrite_fold_opcode << m_info; + } + fold_info const & get_info() const { return m_info; } +}; + +expr mk_rewrite_fold(expr const & e, location const & loc) { + macro_definition def(new rewrite_fold_macro_cell(reduce_info(loc))); + return mk_macro(def, 1, &e); +} + +bool is_rewrite_fold_step(expr const & e) { + return is_macro(e) && macro_def(e).get_name() == *g_rewrite_fold_name; +} + +fold_info const & get_rewrite_fold_info(expr const & e) { + lean_assert(is_rewrite_fold_step(e)); + return static_cast(macro_def(e).raw())->get_info(); +} + class rewrite_unfold_macro_cell : public rewrite_core_macro_cell { unfold_info m_info; public: @@ -301,7 +331,8 @@ rewrite_info const & get_rewrite_info(expr const & e) { expr mk_rewrite_tactic_expr(buffer const & elems) { lean_assert(std::all_of(elems.begin(), elems.end(), [](expr const & e) { - return is_rewrite_step(e) || is_rewrite_unfold_step(e) || is_rewrite_reduce_step(e); + return is_rewrite_step(e) || is_rewrite_unfold_step(e) || + is_rewrite_reduce_step(e) || is_rewrite_fold_step(e); })); return mk_app(*g_rewrite_tac, mk_expr_list(elems.size(), elems.data())); } @@ -545,6 +576,78 @@ class rewrite_fn { return process_reduce_step(info.get_names(), info.get_location()); } + optional fold(expr const & type, expr const & e, occurrence const & occ) { + auto ecs = m_elab(m_g, m_ngen.mk_child(), e, false); + expr new_e = ecs.first; + if (ecs.second) + return none_expr(); // contain constraints... + optional unfolded_e = unfold_app(m_env, new_e); + if (!unfolded_e) + return none_expr(); + bool use_cache = occ.is_all(); + unsigned occ_idx = 0; + bool found = false; + expr new_type = + replace(type, [&](expr const & t, unsigned) { + if (closed(t)) { + constraint_seq cs; + if (m_matcher_tc->is_def_eq(t, *unfolded_e, justification(), cs) && !cs) { + occ_idx++; + if (occ.contains(occ_idx)) { + found = true; + return some_expr(new_e); + } + } + } + return none_expr(); + }, use_cache); + if (found) + return some_expr(new_type); + else + return none_expr(); + } + + bool process_fold_goal(expr const & e, occurrence const & occ) { + if (auto new_type = fold(m_g.get_type(), e, occ)) { + replace_goal(*new_type); + return true; + } else { + return false; + } + } + + bool process_fold_hypothesis(expr const & hyp, expr const & e, occurrence const & occ) { + if (auto new_hyp_type = fold(mlocal_type(hyp), e, occ)) { + replace_hypothesis(hyp, *new_hyp_type); + return true; + } else { + return false; + } + } + + bool process_fold_step(expr const & elem) { + lean_assert(is_rewrite_fold_step(elem)); + location const & loc = get_rewrite_fold_info(elem).get_location(); + expr const & e = macro_arg(elem, 0); + if (loc.is_goal_only()) + return process_fold_goal(e, *loc.includes_goal()); + bool progress = false; + buffer hyps; + m_g.get_hyps(hyps); + for (expr const & h : hyps) { + auto occ = loc.includes_hypothesis(local_pp_name(h)); + if (!occ) + continue; + if (process_fold_hypothesis(h, e, *occ)) + progress = true; + } + if (auto occ = loc.includes_goal()) { + if (process_fold_goal(e, *occ)) + progress = true; + } + return progress; + } + optional unify_with(expr const & t, expr const & e) { auto ecs = m_elab(m_g, m_ngen.mk_child(), e, false); expr new_e = ecs.first; @@ -783,7 +886,8 @@ class rewrite_fn { return unify_result(); } - // target, new_target, H : represents the rewrite (H : target = new_target) for hypothesis and (H : new_target = target) for goals + // target, new_target, H : represents the rewrite (H : target = new_target) for hypothesis + // and (H : new_target = target) for goals typedef optional> find_result; // Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule @@ -977,6 +1081,8 @@ class rewrite_fn { bool process_step(expr const & elem) { if (is_rewrite_unfold_step(elem)) { return process_unfold_step(elem); + } else if (is_rewrite_fold_step(elem)) { + return process_fold_step(elem); } else if (is_rewrite_reduce_step(elem)) { return process_reduce_step(elem); } else { @@ -1067,15 +1173,19 @@ tactic mk_rewrite_tactic(elaborate_fn const & elab, buffer const & elems) void initialize_rewrite_tactic() { g_rewriter_max_iterations = new name{"rewriter", "max_iter"}; - register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations"); + register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, + "(rewriter tactic) maximum number of iterations"); g_rewriter_syntactic = new name{"rewriter", "syntactic"}; - register_bool_option(*g_rewriter_syntactic, LEAN_DEFAULT_REWRITER_SYNTACTIC, "(rewriter tactic) if true tactic will not unfold any constant when performing pattern matching"); + register_bool_option(*g_rewriter_syntactic, LEAN_DEFAULT_REWRITER_SYNTACTIC, + "(rewriter tactic) if true tactic will not unfold any constant when performing pattern matching"); name rewrite_tac_name{"tactic", "rewrite_tac"}; g_rewrite_tac = new expr(Const(rewrite_tac_name)); g_rewrite_reduce_name = new name("rewrite_reduce"); g_rewrite_reduce_opcode = new std::string("RWR"); g_rewrite_unfold_name = new name("rewrite_unfold"); g_rewrite_unfold_opcode = new std::string("RWU"); + g_rewrite_fold_name = new name("rewrite_fold"); + g_rewrite_fold_opcode = new std::string("RWF"); g_rewrite_elem_name = new name("rewrite_element"); g_rewrite_elem_opcode = new std::string("RWE"); register_macro_deserializer(*g_rewrite_reduce_opcode, @@ -1089,6 +1199,14 @@ void initialize_rewrite_tactic() { else return mk_rewrite_reduce_to(args[0], info.get_location()); }); + register_macro_deserializer(*g_rewrite_fold_opcode, + [](deserializer & d, unsigned num, expr const * args) { + if (num != 1) + throw corrupted_stream_exception(); + fold_info info; + d >> info; + return mk_rewrite_fold(args[0], info.get_location()); + }); register_macro_deserializer(*g_rewrite_unfold_opcode, [](deserializer & d, unsigned num, expr const *) { if (num != 0) @@ -1112,7 +1230,8 @@ void initialize_rewrite_tactic() { buffer args; get_tactic_expr_list_elements(app_arg(e), args, "invalid 'rewrite' tactic, invalid argument"); for (expr const & arg : args) { - if (!is_rewrite_step(arg) && !is_rewrite_unfold_step(arg) && !is_rewrite_reduce_step(arg)) + if (!is_rewrite_step(arg) && !is_rewrite_unfold_step(arg) && + !is_rewrite_reduce_step(arg) && !is_rewrite_fold_step(arg)) throw expr_to_tactic_exception(e, "invalid 'rewrite' tactic, invalid argument"); } return mk_rewrite_tactic(elab, args); diff --git a/src/library/tactic/rewrite_tactic.h b/src/library/tactic/rewrite_tactic.h index 08bcfbb10..c2e1870bb 100644 --- a/src/library/tactic/rewrite_tactic.h +++ b/src/library/tactic/rewrite_tactic.h @@ -12,6 +12,7 @@ namespace lean { expr mk_rewrite_unfold(list const & ns, location const & loc); expr mk_rewrite_reduce(location const & loc); expr mk_rewrite_reduce_to(expr const & e, location const & loc); +expr mk_rewrite_fold(expr const & e, location const & loc); expr mk_rewrite_once(optional const & pattern, expr const & H, bool symm, location const & loc); expr mk_rewrite_zero_or_more(optional const & pattern, expr const & H, bool symm, location const & loc); expr mk_rewrite_one_or_more(optional const & pattern, expr const & H, bool symm, location const & loc);