feat(library/tactic/rewrite_tactic): add "reduction" step to rewrite tactic
This commit is contained in:
parent
808521223b
commit
7cdc88701d
2 changed files with 98 additions and 39 deletions
|
@ -49,19 +49,31 @@ public:
|
||||||
unfold_info(name const & n, location const & loc):m_name(n), m_location(loc) {}
|
unfold_info(name const & n, location const & loc):m_name(n), m_location(loc) {}
|
||||||
name const & get_name() const { return m_name; }
|
name const & get_name() const { return m_name; }
|
||||||
location const & get_location() const { return m_location; }
|
location const & get_location() const { return m_location; }
|
||||||
friend serializer & operator<<(serializer & s, unfold_info const & elem);
|
friend serializer & operator<<(serializer & s, unfold_info const & e) {
|
||||||
friend deserializer & operator>>(deserializer & d, unfold_info & e);
|
|
||||||
};
|
|
||||||
|
|
||||||
serializer & operator<<(serializer & s, unfold_info const & e) {
|
|
||||||
s << e.m_name << e.m_location;
|
s << e.m_name << e.m_location;
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
friend deserializer & operator>>(deserializer & d, unfold_info & e) {
|
||||||
deserializer & operator>>(deserializer & d, unfold_info & e) {
|
|
||||||
d >> e.m_name >> e.m_location;
|
d >> e.m_name >> e.m_location;
|
||||||
return d;
|
return d;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class reduce_info {
|
||||||
|
location m_location;
|
||||||
|
public:
|
||||||
|
reduce_info() {}
|
||||||
|
reduce_info(location const & loc):m_location(loc) {}
|
||||||
|
location const & get_location() const { return m_location; }
|
||||||
|
friend serializer & operator<<(serializer & s, reduce_info const & e) {
|
||||||
|
s << e.m_location;
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
friend deserializer & operator>>(deserializer & d, reduce_info & e) {
|
||||||
|
d >> e.m_location;
|
||||||
|
return d;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class rewrite_info {
|
class rewrite_info {
|
||||||
public:
|
public:
|
||||||
|
@ -115,18 +127,14 @@ public:
|
||||||
|
|
||||||
location const & get_location() const { return m_location; }
|
location const & get_location() const { return m_location; }
|
||||||
|
|
||||||
friend serializer & operator<<(serializer & s, rewrite_info const & elem);
|
friend serializer & operator<<(serializer & s, rewrite_info const & e) {
|
||||||
friend deserializer & operator>>(deserializer & d, rewrite_info & e);
|
|
||||||
};
|
|
||||||
|
|
||||||
serializer & operator<<(serializer & s, rewrite_info const & e) {
|
|
||||||
s << e.m_symm << static_cast<char>(e.m_multiplicity) << e.m_location;
|
s << e.m_symm << static_cast<char>(e.m_multiplicity) << e.m_location;
|
||||||
if (e.has_num())
|
if (e.has_num())
|
||||||
s << e.num();
|
s << e.num();
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
deserializer & operator>>(deserializer & d, rewrite_info & e) {
|
friend deserializer & operator>>(deserializer & d, rewrite_info & e) {
|
||||||
char multp;
|
char multp;
|
||||||
d >> e.m_symm >> multp >> e.m_location;
|
d >> e.m_symm >> multp >> e.m_location;
|
||||||
e.m_multiplicity = static_cast<rewrite_info::multiplicity>(multp);
|
e.m_multiplicity = static_cast<rewrite_info::multiplicity>(multp);
|
||||||
|
@ -134,6 +142,7 @@ deserializer & operator>>(deserializer & d, rewrite_info & e) {
|
||||||
e.m_num = d.read_unsigned();
|
e.m_num = d.read_unsigned();
|
||||||
return d;
|
return d;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
static expr * g_rewrite_tac = nullptr;
|
static expr * g_rewrite_tac = nullptr;
|
||||||
|
|
||||||
|
@ -143,16 +152,52 @@ static std::string * g_rewrite_elem_opcode = nullptr;
|
||||||
static name * g_rewrite_unfold_name = nullptr;
|
static name * g_rewrite_unfold_name = nullptr;
|
||||||
static std::string * g_rewrite_unfold_opcode = nullptr;
|
static std::string * g_rewrite_unfold_opcode = nullptr;
|
||||||
|
|
||||||
[[ noreturn ]] static void throw_ru_ex() { throw exception("unexpected occurrence of 'rewrite unfold' expression"); }
|
static name * g_rewrite_reduce_name = nullptr;
|
||||||
[[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite element' expression"); }
|
static std::string * g_rewrite_reduce_opcode = nullptr;
|
||||||
|
|
||||||
class rewrite_unfold_macro_cell : public macro_definition_cell {
|
[[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite' expression"); }
|
||||||
|
|
||||||
|
class rewrite_core_macro_cell : public macro_definition_cell {
|
||||||
|
public:
|
||||||
|
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_re_ex(); }
|
||||||
|
virtual optional<expr> expand(expr const &, extension_context &) const { throw_re_ex(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
class rewrite_reduce_macro_cell : public rewrite_core_macro_cell {
|
||||||
|
reduce_info m_info;
|
||||||
|
public:
|
||||||
|
rewrite_reduce_macro_cell(reduce_info const & info):m_info(info) {}
|
||||||
|
virtual name get_name() const { return *g_rewrite_reduce_name; }
|
||||||
|
virtual void write(serializer & s) const {
|
||||||
|
s << *g_rewrite_reduce_opcode << m_info;
|
||||||
|
}
|
||||||
|
reduce_info const & get_info() const { return m_info; }
|
||||||
|
};
|
||||||
|
|
||||||
|
expr mk_rewrite_reduce(location const & loc) {
|
||||||
|
macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc)));
|
||||||
|
return mk_macro(def);
|
||||||
|
}
|
||||||
|
|
||||||
|
expr mk_rewrite_reduce_to(expr const & e, location const & loc) {
|
||||||
|
macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc)));
|
||||||
|
return mk_macro(def, 1, &e);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_rewrite_reduce_step(expr const & e) {
|
||||||
|
return is_macro(e) && macro_def(e).get_name() == *g_rewrite_reduce_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
reduce_info const & get_rewrite_reduce_info(expr const & e) {
|
||||||
|
lean_assert(is_rewrite_reduce_step(e));
|
||||||
|
return static_cast<rewrite_reduce_macro_cell const*>(macro_def(e).raw())->get_info();
|
||||||
|
}
|
||||||
|
|
||||||
|
class rewrite_unfold_macro_cell : public rewrite_core_macro_cell {
|
||||||
unfold_info m_info;
|
unfold_info m_info;
|
||||||
public:
|
public:
|
||||||
rewrite_unfold_macro_cell(unfold_info const & info):m_info(info) {}
|
rewrite_unfold_macro_cell(unfold_info const & info):m_info(info) {}
|
||||||
virtual name get_name() const { return *g_rewrite_unfold_name; }
|
virtual name get_name() const { return *g_rewrite_unfold_name; }
|
||||||
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_ru_ex(); }
|
|
||||||
virtual optional<expr> expand(expr const &, extension_context &) const { throw_ru_ex(); }
|
|
||||||
virtual void write(serializer & s) const {
|
virtual void write(serializer & s) const {
|
||||||
s << *g_rewrite_unfold_opcode << m_info;
|
s << *g_rewrite_unfold_opcode << m_info;
|
||||||
}
|
}
|
||||||
|
@ -173,13 +218,11 @@ unfold_info const & get_rewrite_unfold_info(expr const & e) {
|
||||||
return static_cast<rewrite_unfold_macro_cell const*>(macro_def(e).raw())->get_info();
|
return static_cast<rewrite_unfold_macro_cell const*>(macro_def(e).raw())->get_info();
|
||||||
}
|
}
|
||||||
|
|
||||||
class rewrite_element_macro_cell : public macro_definition_cell {
|
class rewrite_element_macro_cell : public rewrite_core_macro_cell {
|
||||||
rewrite_info m_info;
|
rewrite_info m_info;
|
||||||
public:
|
public:
|
||||||
rewrite_element_macro_cell(rewrite_info const & info):m_info(info) {}
|
rewrite_element_macro_cell(rewrite_info const & info):m_info(info) {}
|
||||||
virtual name get_name() const { return *g_rewrite_elem_name; }
|
virtual name get_name() const { return *g_rewrite_elem_name; }
|
||||||
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_re_ex(); }
|
|
||||||
virtual optional<expr> expand(expr const &, extension_context &) const { throw_re_ex(); }
|
|
||||||
virtual void write(serializer & s) const {
|
virtual void write(serializer & s) const {
|
||||||
s << *g_rewrite_elem_opcode << m_info;
|
s << *g_rewrite_elem_opcode << m_info;
|
||||||
}
|
}
|
||||||
|
@ -376,8 +419,9 @@ class rewrite_fn {
|
||||||
}
|
}
|
||||||
|
|
||||||
[[ noreturn ]] void throw_max_iter_exceeded() {
|
[[ noreturn ]] void throw_max_iter_exceeded() {
|
||||||
throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded (current threshold: "
|
throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded "
|
||||||
<< m_max_iter << ", increase the threshold by setting option 'rewrite.max_iter')");
|
<< "(current threshold: " << m_max_iter
|
||||||
|
<< ", increase the threshold by setting option 'rewrite.max_iter')");
|
||||||
}
|
}
|
||||||
|
|
||||||
void update_goal(goal const & g) {
|
void update_goal(goal const & g) {
|
||||||
|
@ -919,10 +963,23 @@ void initialize_rewrite_tactic() {
|
||||||
register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations");
|
register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations");
|
||||||
name rewrite_tac_name{"tactic", "rewrite_tac"};
|
name rewrite_tac_name{"tactic", "rewrite_tac"};
|
||||||
g_rewrite_tac = new expr(Const(rewrite_tac_name));
|
g_rewrite_tac = new expr(Const(rewrite_tac_name));
|
||||||
|
g_rewrite_reduce_name = new name("rewrite_reduce");
|
||||||
|
g_rewrite_reduce_opcode = new std::string("RWR");
|
||||||
g_rewrite_unfold_name = new name("rewrite_unfold");
|
g_rewrite_unfold_name = new name("rewrite_unfold");
|
||||||
g_rewrite_unfold_opcode = new std::string("RWU");
|
g_rewrite_unfold_opcode = new std::string("RWU");
|
||||||
g_rewrite_elem_name = new name("rewrite_element");
|
g_rewrite_elem_name = new name("rewrite_element");
|
||||||
g_rewrite_elem_opcode = new std::string("RWE");
|
g_rewrite_elem_opcode = new std::string("RWE");
|
||||||
|
register_macro_deserializer(*g_rewrite_reduce_opcode,
|
||||||
|
[](deserializer & d, unsigned num, expr const * args) {
|
||||||
|
if (num > 1)
|
||||||
|
throw corrupted_stream_exception();
|
||||||
|
unfold_info info;
|
||||||
|
d >> info;
|
||||||
|
if (num == 0)
|
||||||
|
return mk_rewrite_reduce(info.get_location());
|
||||||
|
else
|
||||||
|
return mk_rewrite_reduce_to(args[0], info.get_location());
|
||||||
|
});
|
||||||
register_macro_deserializer(*g_rewrite_unfold_opcode,
|
register_macro_deserializer(*g_rewrite_unfold_opcode,
|
||||||
[](deserializer & d, unsigned num, expr const *) {
|
[](deserializer & d, unsigned num, expr const *) {
|
||||||
if (num != 0)
|
if (num != 0)
|
||||||
|
|
|
@ -10,6 +10,8 @@ Author: Leonardo de Moura
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
expr mk_rewrite_unfold(name const & n, location const & loc);
|
expr mk_rewrite_unfold(name const & n, location const & loc);
|
||||||
|
expr mk_rewrite_reduce(location const & loc);
|
||||||
|
expr mk_rewrite_reduce_to(expr const & e, location const & loc);
|
||||||
expr mk_rewrite_once(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
expr mk_rewrite_once(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
||||||
expr mk_rewrite_zero_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
expr mk_rewrite_zero_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
||||||
expr mk_rewrite_one_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
expr mk_rewrite_one_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
|
||||||
|
|
Loading…
Reference in a new issue