feat(library/tactic/rewrite_tactic): add "reduction" step to rewrite tactic
This commit is contained in:
parent
808521223b
commit
7cdc88701d
2 changed files with 98 additions and 39 deletions
|
@ -49,19 +49,31 @@ public:
|
|||
unfold_info(name const & n, location const & loc):m_name(n), m_location(loc) {}
|
||||
name const & get_name() const { return m_name; }
|
||||
location const & get_location() const { return m_location; }
|
||||
friend serializer & operator<<(serializer & s, unfold_info const & elem);
|
||||
friend deserializer & operator>>(deserializer & d, unfold_info & e);
|
||||
};
|
||||
|
||||
serializer & operator<<(serializer & s, unfold_info const & e) {
|
||||
friend serializer & operator<<(serializer & s, unfold_info const & e) {
|
||||
s << e.m_name << e.m_location;
|
||||
return s;
|
||||
}
|
||||
|
||||
deserializer & operator>>(deserializer & d, unfold_info & e) {
|
||||
friend deserializer & operator>>(deserializer & d, unfold_info & e) {
|
||||
d >> e.m_name >> e.m_location;
|
||||
return d;
|
||||
}
|
||||
};
|
||||
|
||||
class reduce_info {
|
||||
location m_location;
|
||||
public:
|
||||
reduce_info() {}
|
||||
reduce_info(location const & loc):m_location(loc) {}
|
||||
location const & get_location() const { return m_location; }
|
||||
friend serializer & operator<<(serializer & s, reduce_info const & e) {
|
||||
s << e.m_location;
|
||||
return s;
|
||||
}
|
||||
friend deserializer & operator>>(deserializer & d, reduce_info & e) {
|
||||
d >> e.m_location;
|
||||
return d;
|
||||
}
|
||||
};
|
||||
|
||||
class rewrite_info {
|
||||
public:
|
||||
|
@ -115,18 +127,14 @@ public:
|
|||
|
||||
location const & get_location() const { return m_location; }
|
||||
|
||||
friend serializer & operator<<(serializer & s, rewrite_info const & elem);
|
||||
friend deserializer & operator>>(deserializer & d, rewrite_info & e);
|
||||
};
|
||||
|
||||
serializer & operator<<(serializer & s, rewrite_info const & e) {
|
||||
friend serializer & operator<<(serializer & s, rewrite_info const & e) {
|
||||
s << e.m_symm << static_cast<char>(e.m_multiplicity) << e.m_location;
|
||||
if (e.has_num())
|
||||
s << e.num();
|
||||
return s;
|
||||
}
|
||||
|
||||
deserializer & operator>>(deserializer & d, rewrite_info & e) {
|
||||
friend deserializer & operator>>(deserializer & d, rewrite_info & e) {
|
||||
char multp;
|
||||
d >> e.m_symm >> multp >> e.m_location;
|
||||
e.m_multiplicity = static_cast<rewrite_info::multiplicity>(multp);
|
||||
|
@ -134,6 +142,7 @@ deserializer & operator>>(deserializer & d, rewrite_info & e) {
|
|||
e.m_num = d.read_unsigned();
|
||||
return d;
|
||||
}
|
||||
};
|
||||
|
||||
static expr * g_rewrite_tac = nullptr;
|
||||
|
||||
|
@ -143,16 +152,52 @@ static std::string * g_rewrite_elem_opcode = nullptr;
|
|||
static name * g_rewrite_unfold_name = nullptr;
|
||||
static std::string * g_rewrite_unfold_opcode = nullptr;
|
||||
|
||||
[[ noreturn ]] static void throw_ru_ex() { throw exception("unexpected occurrence of 'rewrite unfold' expression"); }
|
||||
[[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite element' expression"); }
|
||||
static name * g_rewrite_reduce_name = nullptr;
|
||||
static std::string * g_rewrite_reduce_opcode = nullptr;
|
||||
|
||||
class rewrite_unfold_macro_cell : public macro_definition_cell {
|
||||
[[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite' expression"); }
|
||||
|
||||
class rewrite_core_macro_cell : public macro_definition_cell {
|
||||
public:
|
||||
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_re_ex(); }
|
||||
virtual optional<expr> expand(expr const &, extension_context &) const { throw_re_ex(); }
|
||||
};
|
||||
|
||||
class rewrite_reduce_macro_cell : public rewrite_core_macro_cell {
|
||||
reduce_info m_info;
|
||||
public:
|
||||
rewrite_reduce_macro_cell(reduce_info const & info):m_info(info) {}
|
||||
virtual name get_name() const { return *g_rewrite_reduce_name; }
|
||||
virtual void write(serializer & s) const {
|
||||
s << *g_rewrite_reduce_opcode << m_info;
|
||||
}
|
||||
reduce_info const & get_info() const { return m_info; }
|
||||
};
|
||||
|
||||
expr mk_rewrite_reduce(location const & loc) {
|
||||
macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc)));
|
||||
return mk_macro(def);
|
||||
}
|
||||
|
||||
expr mk_rewrite_reduce_to(expr const & e, location const & loc) {
|
||||
macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc)));
|
||||
return mk_macro(def, 1, &e);
|
||||
}
|
||||
|
||||
bool is_rewrite_reduce_step(expr const & e) {
|
||||
return is_macro(e) && macro_def(e).get_name() == *g_rewrite_reduce_name;
|
||||
}
|
||||
|
||||
reduce_info const & get_rewrite_reduce_info(expr const & e) {
|
||||
lean_assert(is_rewrite_reduce_step(e));
|
||||
return static_cast<rewrite_reduce_macro_cell const*>(macro_def(e).raw())->get_info();
|
||||
}
|
||||
|
||||
class rewrite_unfold_macro_cell : public rewrite_core_macro_cell {
|
||||
unfold_info m_info;
|
||||
public:
|
||||
rewrite_unfold_macro_cell(unfold_info const & info):m_info(info) {}
|
||||
virtual name get_name() const { return *g_rewrite_unfold_name; }
|
||||
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_ru_ex(); }
|
||||
virtual optional<expr> expand(expr const &, extension_context &) const { throw_ru_ex(); }
|
||||
virtual void write(serializer & s) const {
|
||||
s << *g_rewrite_unfold_opcode << m_info;
|
||||
}
|
||||
|
@ -173,13 +218,11 @@ unfold_info const & get_rewrite_unfold_info(expr const & e) {
|
|||
return static_cast<rewrite_unfold_macro_cell const*>(macro_def(e).raw())->get_info();
|
||||
}
|
||||
|
||||
class rewrite_element_macro_cell : public macro_definition_cell {
|
||||
class rewrite_element_macro_cell : public rewrite_core_macro_cell {
|
||||
rewrite_info m_info;
|
||||
public:
|
||||
rewrite_element_macro_cell(rewrite_info const & info):m_info(info) {}
|
||||
virtual name get_name() const { return *g_rewrite_elem_name; }
|
||||
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_re_ex(); }
|
||||
virtual optional<expr> expand(expr const &, extension_context &) const { throw_re_ex(); }
|
||||
virtual void write(serializer & s) const {
|
||||
s << *g_rewrite_elem_opcode << m_info;
|
||||
}
|
||||
|
@ -376,8 +419,9 @@ class rewrite_fn {
|
|||
}
|
||||
|
||||
[[ noreturn ]] void throw_max_iter_exceeded() {
|
||||
throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded (current threshold: "
|
||||
<< m_max_iter << ", increase the threshold by setting option 'rewrite.max_iter')");
|
||||
throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded "
|
||||
<< "(current threshold: " << m_max_iter
|
||||
<< ", increase the threshold by setting option 'rewrite.max_iter')");
|
||||
}
|
||||
|
||||
void update_goal(goal const & g) {
|
||||
|
@ -919,10 +963,23 @@ void initialize_rewrite_tactic() {
|
|||
register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations");
|
||||
name rewrite_tac_name{"tactic", "rewrite_tac"};
|
||||
g_rewrite_tac = new expr(Const(rewrite_tac_name));
|
||||
g_rewrite_reduce_name = new name("rewrite_reduce");
|
||||
g_rewrite_reduce_opcode = new std::string("RWR");
|
||||
g_rewrite_unfold_name = new name("rewrite_unfold");
|
||||
g_rewrite_unfold_opcode = new std::string("RWU");
|
||||
g_rewrite_elem_name = new name("rewrite_element");
|
||||
g_rewrite_elem_opcode = new std::string("RWE");
|
||||
register_macro_deserializer(*g_rewrite_reduce_opcode,
|
||||
[](deserializer & d, unsigned num, expr const * args) {
|
||||
if (num > 1)
|
||||
throw corrupted_stream_exception();
|
||||
unfold_info info;
|
||||
d >> info;
|
||||
if (num == 0)
|
||||
return mk_rewrite_reduce(info.get_location());
|
||||
else
|
||||
return mk_rewrite_reduce_to(args[0], info.get_location());
|
||||
});
|
||||
register_macro_deserializer(*g_rewrite_unfold_opcode,
|
||||
[](deserializer & d, unsigned num, expr const *) {
|
||||
if (num != 0)
|
||||
|
|
|
@ -10,6 +10,8 @@ Author: Leonardo de Moura
|
|||
|
||||
namespace lean {
|
||||
expr mk_rewrite_unfold(name const & n, location const & loc);
|
||||
expr mk_rewrite_reduce(location const & loc);
|
||||
expr mk_rewrite_reduce_to(expr const & e, location const & loc);
|
||||
expr mk_rewrite_once(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
||||
expr mk_rewrite_zero_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
||||
expr mk_rewrite_one_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
||||
|
|
Loading…
Reference in a new issue