feat(library/tactic/rewrite_tactic): add "reduction" step to rewrite tactic

This commit is contained in:
Leonardo de Moura 2015-02-05 13:16:05 -08:00
parent 808521223b
commit 7cdc88701d
2 changed files with 98 additions and 39 deletions

View file

@ -49,19 +49,31 @@ public:
unfold_info(name const & n, location const & loc):m_name(n), m_location(loc) {} unfold_info(name const & n, location const & loc):m_name(n), m_location(loc) {}
name const & get_name() const { return m_name; } name const & get_name() const { return m_name; }
location const & get_location() const { return m_location; } location const & get_location() const { return m_location; }
friend serializer & operator<<(serializer & s, unfold_info const & elem); friend serializer & operator<<(serializer & s, unfold_info const & e) {
friend deserializer & operator>>(deserializer & d, unfold_info & e);
};
serializer & operator<<(serializer & s, unfold_info const & e) {
s << e.m_name << e.m_location; s << e.m_name << e.m_location;
return s; return s;
} }
friend deserializer & operator>>(deserializer & d, unfold_info & e) {
deserializer & operator>>(deserializer & d, unfold_info & e) {
d >> e.m_name >> e.m_location; d >> e.m_name >> e.m_location;
return d; return d;
} }
};
class reduce_info {
location m_location;
public:
reduce_info() {}
reduce_info(location const & loc):m_location(loc) {}
location const & get_location() const { return m_location; }
friend serializer & operator<<(serializer & s, reduce_info const & e) {
s << e.m_location;
return s;
}
friend deserializer & operator>>(deserializer & d, reduce_info & e) {
d >> e.m_location;
return d;
}
};
class rewrite_info { class rewrite_info {
public: public:
@ -115,25 +127,22 @@ public:
location const & get_location() const { return m_location; } location const & get_location() const { return m_location; }
friend serializer & operator<<(serializer & s, rewrite_info const & elem); friend serializer & operator<<(serializer & s, rewrite_info const & e) {
friend deserializer & operator>>(deserializer & d, rewrite_info & e);
};
serializer & operator<<(serializer & s, rewrite_info const & e) {
s << e.m_symm << static_cast<char>(e.m_multiplicity) << e.m_location; s << e.m_symm << static_cast<char>(e.m_multiplicity) << e.m_location;
if (e.has_num()) if (e.has_num())
s << e.num(); s << e.num();
return s; return s;
} }
deserializer & operator>>(deserializer & d, rewrite_info & e) { friend deserializer & operator>>(deserializer & d, rewrite_info & e) {
char multp; char multp;
d >> e.m_symm >> multp >> e.m_location; d >> e.m_symm >> multp >> e.m_location;
e.m_multiplicity = static_cast<rewrite_info::multiplicity>(multp); e.m_multiplicity = static_cast<rewrite_info::multiplicity>(multp);
if (e.has_num()) if (e.has_num())
e.m_num = d.read_unsigned(); e.m_num = d.read_unsigned();
return d; return d;
} }
};
static expr * g_rewrite_tac = nullptr; static expr * g_rewrite_tac = nullptr;
@ -143,16 +152,52 @@ static std::string * g_rewrite_elem_opcode = nullptr;
static name * g_rewrite_unfold_name = nullptr; static name * g_rewrite_unfold_name = nullptr;
static std::string * g_rewrite_unfold_opcode = nullptr; static std::string * g_rewrite_unfold_opcode = nullptr;
[[ noreturn ]] static void throw_ru_ex() { throw exception("unexpected occurrence of 'rewrite unfold' expression"); } static name * g_rewrite_reduce_name = nullptr;
[[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite element' expression"); } static std::string * g_rewrite_reduce_opcode = nullptr;
class rewrite_unfold_macro_cell : public macro_definition_cell { [[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite' expression"); }
class rewrite_core_macro_cell : public macro_definition_cell {
public:
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_re_ex(); }
virtual optional<expr> expand(expr const &, extension_context &) const { throw_re_ex(); }
};
class rewrite_reduce_macro_cell : public rewrite_core_macro_cell {
reduce_info m_info;
public:
rewrite_reduce_macro_cell(reduce_info const & info):m_info(info) {}
virtual name get_name() const { return *g_rewrite_reduce_name; }
virtual void write(serializer & s) const {
s << *g_rewrite_reduce_opcode << m_info;
}
reduce_info const & get_info() const { return m_info; }
};
expr mk_rewrite_reduce(location const & loc) {
macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc)));
return mk_macro(def);
}
expr mk_rewrite_reduce_to(expr const & e, location const & loc) {
macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc)));
return mk_macro(def, 1, &e);
}
bool is_rewrite_reduce_step(expr const & e) {
return is_macro(e) && macro_def(e).get_name() == *g_rewrite_reduce_name;
}
reduce_info const & get_rewrite_reduce_info(expr const & e) {
lean_assert(is_rewrite_reduce_step(e));
return static_cast<rewrite_reduce_macro_cell const*>(macro_def(e).raw())->get_info();
}
class rewrite_unfold_macro_cell : public rewrite_core_macro_cell {
unfold_info m_info; unfold_info m_info;
public: public:
rewrite_unfold_macro_cell(unfold_info const & info):m_info(info) {} rewrite_unfold_macro_cell(unfold_info const & info):m_info(info) {}
virtual name get_name() const { return *g_rewrite_unfold_name; } virtual name get_name() const { return *g_rewrite_unfold_name; }
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_ru_ex(); }
virtual optional<expr> expand(expr const &, extension_context &) const { throw_ru_ex(); }
virtual void write(serializer & s) const { virtual void write(serializer & s) const {
s << *g_rewrite_unfold_opcode << m_info; s << *g_rewrite_unfold_opcode << m_info;
} }
@ -173,13 +218,11 @@ unfold_info const & get_rewrite_unfold_info(expr const & e) {
return static_cast<rewrite_unfold_macro_cell const*>(macro_def(e).raw())->get_info(); return static_cast<rewrite_unfold_macro_cell const*>(macro_def(e).raw())->get_info();
} }
class rewrite_element_macro_cell : public macro_definition_cell { class rewrite_element_macro_cell : public rewrite_core_macro_cell {
rewrite_info m_info; rewrite_info m_info;
public: public:
rewrite_element_macro_cell(rewrite_info const & info):m_info(info) {} rewrite_element_macro_cell(rewrite_info const & info):m_info(info) {}
virtual name get_name() const { return *g_rewrite_elem_name; } virtual name get_name() const { return *g_rewrite_elem_name; }
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_re_ex(); }
virtual optional<expr> expand(expr const &, extension_context &) const { throw_re_ex(); }
virtual void write(serializer & s) const { virtual void write(serializer & s) const {
s << *g_rewrite_elem_opcode << m_info; s << *g_rewrite_elem_opcode << m_info;
} }
@ -376,8 +419,9 @@ class rewrite_fn {
} }
[[ noreturn ]] void throw_max_iter_exceeded() { [[ noreturn ]] void throw_max_iter_exceeded() {
throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded (current threshold: " throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded "
<< m_max_iter << ", increase the threshold by setting option 'rewrite.max_iter')"); << "(current threshold: " << m_max_iter
<< ", increase the threshold by setting option 'rewrite.max_iter')");
} }
void update_goal(goal const & g) { void update_goal(goal const & g) {
@ -919,10 +963,23 @@ void initialize_rewrite_tactic() {
register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations"); register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations");
name rewrite_tac_name{"tactic", "rewrite_tac"}; name rewrite_tac_name{"tactic", "rewrite_tac"};
g_rewrite_tac = new expr(Const(rewrite_tac_name)); g_rewrite_tac = new expr(Const(rewrite_tac_name));
g_rewrite_reduce_name = new name("rewrite_reduce");
g_rewrite_reduce_opcode = new std::string("RWR");
g_rewrite_unfold_name = new name("rewrite_unfold"); g_rewrite_unfold_name = new name("rewrite_unfold");
g_rewrite_unfold_opcode = new std::string("RWU"); g_rewrite_unfold_opcode = new std::string("RWU");
g_rewrite_elem_name = new name("rewrite_element"); g_rewrite_elem_name = new name("rewrite_element");
g_rewrite_elem_opcode = new std::string("RWE"); g_rewrite_elem_opcode = new std::string("RWE");
register_macro_deserializer(*g_rewrite_reduce_opcode,
[](deserializer & d, unsigned num, expr const * args) {
if (num > 1)
throw corrupted_stream_exception();
unfold_info info;
d >> info;
if (num == 0)
return mk_rewrite_reduce(info.get_location());
else
return mk_rewrite_reduce_to(args[0], info.get_location());
});
register_macro_deserializer(*g_rewrite_unfold_opcode, register_macro_deserializer(*g_rewrite_unfold_opcode,
[](deserializer & d, unsigned num, expr const *) { [](deserializer & d, unsigned num, expr const *) {
if (num != 0) if (num != 0)

View file

@ -10,6 +10,8 @@ Author: Leonardo de Moura
namespace lean { namespace lean {
expr mk_rewrite_unfold(name const & n, location const & loc); expr mk_rewrite_unfold(name const & n, location const & loc);
expr mk_rewrite_reduce(location const & loc);
expr mk_rewrite_reduce_to(expr const & e, location const & loc);
expr mk_rewrite_once(optional<expr> const & pattern, expr const & H, bool symm, location const & loc); expr mk_rewrite_once(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
expr mk_rewrite_zero_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc); expr mk_rewrite_zero_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
expr mk_rewrite_one_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc); expr mk_rewrite_one_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);