diff --git a/src/library/tactic/rewrite_tactic.cpp b/src/library/tactic/rewrite_tactic.cpp index a555ed2ac..feb5803a7 100644 --- a/src/library/tactic/rewrite_tactic.cpp +++ b/src/library/tactic/rewrite_tactic.cpp @@ -49,19 +49,31 @@ public: unfold_info(name const & n, location const & loc):m_name(n), m_location(loc) {} name const & get_name() const { return m_name; } location const & get_location() const { return m_location; } - friend serializer & operator<<(serializer & s, unfold_info const & elem); - friend deserializer & operator>>(deserializer & d, unfold_info & e); + friend serializer & operator<<(serializer & s, unfold_info const & e) { + s << e.m_name << e.m_location; + return s; + } + friend deserializer & operator>>(deserializer & d, unfold_info & e) { + d >> e.m_name >> e.m_location; + return d; + } }; -serializer & operator<<(serializer & s, unfold_info const & e) { - s << e.m_name << e.m_location; - return s; -} - -deserializer & operator>>(deserializer & d, unfold_info & e) { - d >> e.m_name >> e.m_location; - return d; -} +class reduce_info { + location m_location; +public: + reduce_info() {} + reduce_info(location const & loc):m_location(loc) {} + location const & get_location() const { return m_location; } + friend serializer & operator<<(serializer & s, reduce_info const & e) { + s << e.m_location; + return s; + } + friend deserializer & operator>>(deserializer & d, reduce_info & e) { + d >> e.m_location; + return d; + } +}; class rewrite_info { public: @@ -115,26 +127,23 @@ public: location const & get_location() const { return m_location; } - friend serializer & operator<<(serializer & s, rewrite_info const & elem); - friend deserializer & operator>>(deserializer & d, rewrite_info & e); + friend serializer & operator<<(serializer & s, rewrite_info const & e) { + s << e.m_symm << static_cast(e.m_multiplicity) << e.m_location; + if (e.has_num()) + s << e.num(); + return s; + } + + friend deserializer & operator>>(deserializer & d, rewrite_info & e) { + char multp; + d >> e.m_symm >> multp >> e.m_location; + e.m_multiplicity = static_cast(multp); + if (e.has_num()) + e.m_num = d.read_unsigned(); + return d; + } }; -serializer & operator<<(serializer & s, rewrite_info const & e) { - s << e.m_symm << static_cast(e.m_multiplicity) << e.m_location; - if (e.has_num()) - s << e.num(); - return s; -} - -deserializer & operator>>(deserializer & d, rewrite_info & e) { - char multp; - d >> e.m_symm >> multp >> e.m_location; - e.m_multiplicity = static_cast(multp); - if (e.has_num()) - e.m_num = d.read_unsigned(); - return d; -} - static expr * g_rewrite_tac = nullptr; static name * g_rewrite_elem_name = nullptr; @@ -143,16 +152,52 @@ static std::string * g_rewrite_elem_opcode = nullptr; static name * g_rewrite_unfold_name = nullptr; static std::string * g_rewrite_unfold_opcode = nullptr; -[[ noreturn ]] static void throw_ru_ex() { throw exception("unexpected occurrence of 'rewrite unfold' expression"); } -[[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite element' expression"); } +static name * g_rewrite_reduce_name = nullptr; +static std::string * g_rewrite_reduce_opcode = nullptr; -class rewrite_unfold_macro_cell : public macro_definition_cell { +[[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite' expression"); } + +class rewrite_core_macro_cell : public macro_definition_cell { +public: + virtual pair get_type(expr const &, extension_context &) const { throw_re_ex(); } + virtual optional expand(expr const &, extension_context &) const { throw_re_ex(); } +}; + +class rewrite_reduce_macro_cell : public rewrite_core_macro_cell { + reduce_info m_info; +public: + rewrite_reduce_macro_cell(reduce_info const & info):m_info(info) {} + virtual name get_name() const { return *g_rewrite_reduce_name; } + virtual void write(serializer & s) const { + s << *g_rewrite_reduce_opcode << m_info; + } + reduce_info const & get_info() const { return m_info; } +}; + +expr mk_rewrite_reduce(location const & loc) { + macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc))); + return mk_macro(def); +} + +expr mk_rewrite_reduce_to(expr const & e, location const & loc) { + macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc))); + return mk_macro(def, 1, &e); +} + +bool is_rewrite_reduce_step(expr const & e) { + return is_macro(e) && macro_def(e).get_name() == *g_rewrite_reduce_name; +} + +reduce_info const & get_rewrite_reduce_info(expr const & e) { + lean_assert(is_rewrite_reduce_step(e)); + return static_cast(macro_def(e).raw())->get_info(); +} + +class rewrite_unfold_macro_cell : public rewrite_core_macro_cell { unfold_info m_info; public: rewrite_unfold_macro_cell(unfold_info const & info):m_info(info) {} virtual name get_name() const { return *g_rewrite_unfold_name; } - virtual pair get_type(expr const &, extension_context &) const { throw_ru_ex(); } - virtual optional expand(expr const &, extension_context &) const { throw_ru_ex(); } virtual void write(serializer & s) const { s << *g_rewrite_unfold_opcode << m_info; } @@ -173,13 +218,11 @@ unfold_info const & get_rewrite_unfold_info(expr const & e) { return static_cast(macro_def(e).raw())->get_info(); } -class rewrite_element_macro_cell : public macro_definition_cell { +class rewrite_element_macro_cell : public rewrite_core_macro_cell { rewrite_info m_info; public: rewrite_element_macro_cell(rewrite_info const & info):m_info(info) {} virtual name get_name() const { return *g_rewrite_elem_name; } - virtual pair get_type(expr const &, extension_context &) const { throw_re_ex(); } - virtual optional expand(expr const &, extension_context &) const { throw_re_ex(); } virtual void write(serializer & s) const { s << *g_rewrite_elem_opcode << m_info; } @@ -376,8 +419,9 @@ class rewrite_fn { } [[ noreturn ]] void throw_max_iter_exceeded() { - throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded (current threshold: " - << m_max_iter << ", increase the threshold by setting option 'rewrite.max_iter')"); + throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded " + << "(current threshold: " << m_max_iter + << ", increase the threshold by setting option 'rewrite.max_iter')"); } void update_goal(goal const & g) { @@ -919,10 +963,23 @@ void initialize_rewrite_tactic() { register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations"); name rewrite_tac_name{"tactic", "rewrite_tac"}; g_rewrite_tac = new expr(Const(rewrite_tac_name)); + g_rewrite_reduce_name = new name("rewrite_reduce"); + g_rewrite_reduce_opcode = new std::string("RWR"); g_rewrite_unfold_name = new name("rewrite_unfold"); g_rewrite_unfold_opcode = new std::string("RWU"); g_rewrite_elem_name = new name("rewrite_element"); g_rewrite_elem_opcode = new std::string("RWE"); + register_macro_deserializer(*g_rewrite_reduce_opcode, + [](deserializer & d, unsigned num, expr const * args) { + if (num > 1) + throw corrupted_stream_exception(); + unfold_info info; + d >> info; + if (num == 0) + return mk_rewrite_reduce(info.get_location()); + else + return mk_rewrite_reduce_to(args[0], info.get_location()); + }); register_macro_deserializer(*g_rewrite_unfold_opcode, [](deserializer & d, unsigned num, expr const *) { if (num != 0) diff --git a/src/library/tactic/rewrite_tactic.h b/src/library/tactic/rewrite_tactic.h index 843e0c5da..1014f738d 100644 --- a/src/library/tactic/rewrite_tactic.h +++ b/src/library/tactic/rewrite_tactic.h @@ -10,6 +10,8 @@ Author: Leonardo de Moura namespace lean { expr mk_rewrite_unfold(name const & n, location const & loc); +expr mk_rewrite_reduce(location const & loc); +expr mk_rewrite_reduce_to(expr const & e, location const & loc); expr mk_rewrite_once(optional const & pattern, expr const & H, bool symm, location const & loc); expr mk_rewrite_zero_or_more(optional const & pattern, expr const & H, bool symm, location const & loc); expr mk_rewrite_one_or_more(optional const & pattern, expr const & H, bool symm, location const & loc);