feat(library/tactic/rewrite_tactic): add "reduction" step to rewrite tactic

This commit is contained in:
Leonardo de Moura 2015-02-05 13:16:05 -08:00
parent 808521223b
commit 7cdc88701d
2 changed files with 98 additions and 39 deletions

View file

@ -49,19 +49,31 @@ public:
unfold_info(name const & n, location const & loc):m_name(n), m_location(loc) {}
name const & get_name() const { return m_name; }
location const & get_location() const { return m_location; }
friend serializer & operator<<(serializer & s, unfold_info const & elem);
friend deserializer & operator>>(deserializer & d, unfold_info & e);
};
serializer & operator<<(serializer & s, unfold_info const & e) {
friend serializer & operator<<(serializer & s, unfold_info const & e) {
s << e.m_name << e.m_location;
return s;
}
deserializer & operator>>(deserializer & d, unfold_info & e) {
friend deserializer & operator>>(deserializer & d, unfold_info & e) {
d >> e.m_name >> e.m_location;
return d;
}
};
class reduce_info {
location m_location;
public:
reduce_info() {}
reduce_info(location const & loc):m_location(loc) {}
location const & get_location() const { return m_location; }
friend serializer & operator<<(serializer & s, reduce_info const & e) {
s << e.m_location;
return s;
}
friend deserializer & operator>>(deserializer & d, reduce_info & e) {
d >> e.m_location;
return d;
}
};
class rewrite_info {
public:
@ -115,18 +127,14 @@ public:
location const & get_location() const { return m_location; }
friend serializer & operator<<(serializer & s, rewrite_info const & elem);
friend deserializer & operator>>(deserializer & d, rewrite_info & e);
};
serializer & operator<<(serializer & s, rewrite_info const & e) {
friend serializer & operator<<(serializer & s, rewrite_info const & e) {
s << e.m_symm << static_cast<char>(e.m_multiplicity) << e.m_location;
if (e.has_num())
s << e.num();
return s;
}
deserializer & operator>>(deserializer & d, rewrite_info & e) {
friend deserializer & operator>>(deserializer & d, rewrite_info & e) {
char multp;
d >> e.m_symm >> multp >> e.m_location;
e.m_multiplicity = static_cast<rewrite_info::multiplicity>(multp);
@ -134,6 +142,7 @@ deserializer & operator>>(deserializer & d, rewrite_info & e) {
e.m_num = d.read_unsigned();
return d;
}
};
static expr * g_rewrite_tac = nullptr;
@ -143,16 +152,52 @@ static std::string * g_rewrite_elem_opcode = nullptr;
static name * g_rewrite_unfold_name = nullptr;
static std::string * g_rewrite_unfold_opcode = nullptr;
[[ noreturn ]] static void throw_ru_ex() { throw exception("unexpected occurrence of 'rewrite unfold' expression"); }
[[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite element' expression"); }
static name * g_rewrite_reduce_name = nullptr;
static std::string * g_rewrite_reduce_opcode = nullptr;
class rewrite_unfold_macro_cell : public macro_definition_cell {
[[ noreturn ]] static void throw_re_ex() { throw exception("unexpected occurrence of 'rewrite' expression"); }
class rewrite_core_macro_cell : public macro_definition_cell {
public:
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_re_ex(); }
virtual optional<expr> expand(expr const &, extension_context &) const { throw_re_ex(); }
};
class rewrite_reduce_macro_cell : public rewrite_core_macro_cell {
reduce_info m_info;
public:
rewrite_reduce_macro_cell(reduce_info const & info):m_info(info) {}
virtual name get_name() const { return *g_rewrite_reduce_name; }
virtual void write(serializer & s) const {
s << *g_rewrite_reduce_opcode << m_info;
}
reduce_info const & get_info() const { return m_info; }
};
expr mk_rewrite_reduce(location const & loc) {
macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc)));
return mk_macro(def);
}
expr mk_rewrite_reduce_to(expr const & e, location const & loc) {
macro_definition def(new rewrite_reduce_macro_cell(reduce_info(loc)));
return mk_macro(def, 1, &e);
}
bool is_rewrite_reduce_step(expr const & e) {
return is_macro(e) && macro_def(e).get_name() == *g_rewrite_reduce_name;
}
reduce_info const & get_rewrite_reduce_info(expr const & e) {
lean_assert(is_rewrite_reduce_step(e));
return static_cast<rewrite_reduce_macro_cell const*>(macro_def(e).raw())->get_info();
}
class rewrite_unfold_macro_cell : public rewrite_core_macro_cell {
unfold_info m_info;
public:
rewrite_unfold_macro_cell(unfold_info const & info):m_info(info) {}
virtual name get_name() const { return *g_rewrite_unfold_name; }
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_ru_ex(); }
virtual optional<expr> expand(expr const &, extension_context &) const { throw_ru_ex(); }
virtual void write(serializer & s) const {
s << *g_rewrite_unfold_opcode << m_info;
}
@ -173,13 +218,11 @@ unfold_info const & get_rewrite_unfold_info(expr const & e) {
return static_cast<rewrite_unfold_macro_cell const*>(macro_def(e).raw())->get_info();
}
class rewrite_element_macro_cell : public macro_definition_cell {
class rewrite_element_macro_cell : public rewrite_core_macro_cell {
rewrite_info m_info;
public:
rewrite_element_macro_cell(rewrite_info const & info):m_info(info) {}
virtual name get_name() const { return *g_rewrite_elem_name; }
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_re_ex(); }
virtual optional<expr> expand(expr const &, extension_context &) const { throw_re_ex(); }
virtual void write(serializer & s) const {
s << *g_rewrite_elem_opcode << m_info;
}
@ -376,8 +419,9 @@ class rewrite_fn {
}
[[ noreturn ]] void throw_max_iter_exceeded() {
throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded (current threshold: "
<< m_max_iter << ", increase the threshold by setting option 'rewrite.max_iter')");
throw_rewrite_exception(sstream() << "rewrite tactic failed, maximum number of iterations exceeded "
<< "(current threshold: " << m_max_iter
<< ", increase the threshold by setting option 'rewrite.max_iter')");
}
void update_goal(goal const & g) {
@ -919,10 +963,23 @@ void initialize_rewrite_tactic() {
register_unsigned_option(*g_rewriter_max_iterations, LEAN_DEFAULT_REWRITER_MAX_ITERATIONS, "(rewriter tactic) maximum number of iterations");
name rewrite_tac_name{"tactic", "rewrite_tac"};
g_rewrite_tac = new expr(Const(rewrite_tac_name));
g_rewrite_reduce_name = new name("rewrite_reduce");
g_rewrite_reduce_opcode = new std::string("RWR");
g_rewrite_unfold_name = new name("rewrite_unfold");
g_rewrite_unfold_opcode = new std::string("RWU");
g_rewrite_elem_name = new name("rewrite_element");
g_rewrite_elem_opcode = new std::string("RWE");
register_macro_deserializer(*g_rewrite_reduce_opcode,
[](deserializer & d, unsigned num, expr const * args) {
if (num > 1)
throw corrupted_stream_exception();
unfold_info info;
d >> info;
if (num == 0)
return mk_rewrite_reduce(info.get_location());
else
return mk_rewrite_reduce_to(args[0], info.get_location());
});
register_macro_deserializer(*g_rewrite_unfold_opcode,
[](deserializer & d, unsigned num, expr const *) {
if (num != 0)

View file

@ -10,6 +10,8 @@ Author: Leonardo de Moura
namespace lean {
expr mk_rewrite_unfold(name const & n, location const & loc);
expr mk_rewrite_reduce(location const & loc);
expr mk_rewrite_reduce_to(expr const & e, location const & loc);
expr mk_rewrite_once(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
expr mk_rewrite_zero_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);
expr mk_rewrite_one_or_more(optional<expr> const & pattern, expr const & H, bool symm, location const & loc);