refactor(library/simplifier): rewriter_rule_set

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-18 17:23:41 -08:00
parent feea96e84d
commit 466285c577
4 changed files with 179 additions and 135 deletions

View file

@ -15,12 +15,7 @@ Author: Leonardo de Moura
#include "library/simplifier/rewrite_rule_set.h" #include "library/simplifier/rewrite_rule_set.h"
namespace lean { namespace lean {
/** struct rewrite_rule_set::rewrite_rule {
\brief Actual implementation of the \c rewrite_rule_set class.
*/
class rewrite_rule_set::imp {
typedef splay_tree<name, name_quick_cmp> name_set;
struct rewrite_rule {
name m_id; name m_id;
expr m_lhs; expr m_lhs;
expr m_ceq; expr m_ceq;
@ -29,14 +24,13 @@ class rewrite_rule_set::imp {
rewrite_rule(name const & id, expr const & lhs, expr const & ceq, expr const & proof, bool is_permutation): rewrite_rule(name const & id, expr const & lhs, expr const & ceq, expr const & proof, bool is_permutation):
m_id(id), m_lhs(lhs), m_ceq(ceq), m_proof(proof), m_is_permutation(is_permutation) {} m_id(id), m_lhs(lhs), m_ceq(ceq), m_proof(proof), m_is_permutation(is_permutation) {}
}; };
ro_environment::weak_ref m_env;
list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees
name_set m_disabled_rules;
public: rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_env(env.to_weak_ref()) {}
imp(ro_environment const & env):m_env(env.to_weak_ref()) {} rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other):
m_env(other.m_env), m_rule_set(other.m_rule_set), m_disabled_rules(other.m_disabled_rules) {}
rewrite_rule_set::~rewrite_rule_set() {}
void insert(name const & id, expr const & th, expr const & proof) { void rewrite_rule_set::insert(name const & id, expr const & th, expr const & proof) {
ro_environment env(m_env); ro_environment env(m_env);
for (auto const & p : to_ceqs(env, th, proof)) { for (auto const & p : to_ceqs(env, th, proof)) {
expr const & ceq = p.first; expr const & ceq = p.first;
@ -52,7 +46,7 @@ public:
} }
} }
void insert(name const & th_name) { void rewrite_rule_set::insert(name const & th_name) {
ro_environment env(m_env); ro_environment env(m_env);
auto obj = env->find_object(th_name); auto obj = env->find_object(th_name);
if (obj && (obj->is_theorem() || obj->is_axiom())) { if (obj && (obj->is_theorem() || obj->is_axiom())) {
@ -62,22 +56,22 @@ public:
} }
} }
bool enabled(rewrite_rule const & rule) const { bool rewrite_rule_set::enabled(rewrite_rule const & rule) const {
return !m_disabled_rules.contains(rule.m_id); return !m_disabled_rules.contains(rule.m_id);
} }
bool enabled(name const & id) const { bool rewrite_rule_set::enabled(name const & id) const {
return !m_disabled_rules.contains(id); return !m_disabled_rules.contains(id);
} }
void enable(name const & id, bool f) { void rewrite_rule_set::enable(name const & id, bool f) {
if (f) if (f)
m_disabled_rules.erase(id); m_disabled_rules.erase(id);
else else
m_disabled_rules.insert(id); m_disabled_rules.insert(id);
} }
void for_each_match_candidate(expr const &, match_fn const & fn) const { void rewrite_rule_set::for_each_match_candidate(expr const &, match_fn const & fn) const {
auto l = m_rule_set; auto l = m_rule_set;
for (auto const & rule : l) { for (auto const & rule : l) {
if (enabled(rule) && fn(rule.m_lhs, rule.m_ceq, rule.m_is_permutation, rule.m_proof)) if (enabled(rule) && fn(rule.m_lhs, rule.m_ceq, rule.m_is_permutation, rule.m_proof))
@ -85,14 +79,14 @@ public:
} }
} }
void for_each(visit_fn const & fn) const { void rewrite_rule_set::for_each(visit_fn const & fn) const {
auto l = m_rule_set; auto l = m_rule_set;
for (auto const & rule : l) { for (auto const & rule : l) {
fn(rule.m_id, rule.m_ceq, rule.m_proof, enabled(rule)); fn(rule.m_id, rule.m_ceq, rule.m_proof, enabled(rule));
} }
} }
format pp(formatter const & fmt, options const & opts) const { format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const {
format r; format r;
bool first = true; bool first = true;
unsigned indent = get_pp_indent(opts); unsigned indent = get_pp_indent(opts);
@ -109,20 +103,6 @@ public:
}); });
return r; return r;
} }
};
rewrite_rule_set::rewrite_rule_set(imp * ptr):m_ptr(ptr) {}
rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_ptr(new imp(env)) {}
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & rs):m_ptr(new imp(*rs.m_ptr)) {}
rewrite_rule_set::~rewrite_rule_set() {}
void rewrite_rule_set::insert(name const & id, expr const & th, expr const & proof) { m_ptr->insert(id, th, proof); }
void rewrite_rule_set::insert(name const & th_name) { m_ptr->insert(th_name); }
bool rewrite_rule_set::enabled(name const & id) const { return m_ptr->enabled(id); }
void rewrite_rule_set::enable(name const & id, bool f) { m_ptr->enable(id, f); }
void rewrite_rule_set::for_each_match_candidate(expr const & e, match_fn const & fn) const {
m_ptr->for_each_match_candidate(e, fn);
}
void rewrite_rule_set::for_each(visit_fn const & fn) const { m_ptr->for_each(fn); }
format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const { return m_ptr->pp(fmt, opts); }
class mk_rewrite_rule_set_obj : public neutral_object_cell { class mk_rewrite_rule_set_obj : public neutral_object_cell {
name m_rule_set_id; name m_rule_set_id;
@ -234,7 +214,7 @@ struct rewrite_rule_set_extension : public environment_extension {
env->add_neutral_object(new enable_rewrite_rules_obj(rule_set_id, rule_id, flag)); env->add_neutral_object(new enable_rewrite_rules_obj(rule_set_id, rule_id, flag));
} }
rewrite_rule_set const & get_rewrite_rule_set(name const & rule_set_id) const { rewrite_rule_set get_rewrite_rule_set(name const & rule_set_id) const {
return find_ro(rule_set_id); return find_ro(rule_set_id);
} }
}; };
@ -258,6 +238,11 @@ static rewrite_rule_set_extension & to_ext(environment const & env) {
return env->get_extension<rewrite_rule_set_extension>(g_rewrite_rule_set_extension_initializer.m_extid); return env->get_extension<rewrite_rule_set_extension>(g_rewrite_rule_set_extension_initializer.m_extid);
} }
static name g_default_rewrite_rule_set_id("default");
name const & get_default_rewrite_rule_set_id() {
return g_default_rewrite_rule_set_id;
}
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id) { void mk_rewrite_rule_set(environment const & env, name const & rule_set_id) {
to_ext(env).mk_rewrite_rule_set(env, rule_set_id); to_ext(env).mk_rewrite_rule_set(env, rule_set_id);
} }
@ -270,13 +255,15 @@ void enable_rewrite_rules(environment const & env, name const & rule_set_id, nam
to_ext(env).enable_rewrite_rules(env, rule_set_id, rule_id, flag); to_ext(env).enable_rewrite_rules(env, rule_set_id, rule_id, flag);
} }
rewrite_rule_set const & get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) { rewrite_rule_set get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) {
return to_ext(env).get_rewrite_rule_set(rule_set_id); return to_ext(env).get_rewrite_rule_set(rule_set_id);
} }
static int mk_rewrite_rule_set(lua_State * L) { static int mk_rewrite_rule_set(lua_State * L) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
if (nargs == 1) if (nargs == 0)
mk_rewrite_rule_set(rw_shared_environment(L));
else if (nargs == 1)
mk_rewrite_rule_set(rw_shared_environment(L), to_name_ext(L, 1)); mk_rewrite_rule_set(rw_shared_environment(L), to_name_ext(L, 1));
else else
mk_rewrite_rule_set(rw_shared_environment(L, 2), to_name_ext(L, 1)); mk_rewrite_rule_set(rw_shared_environment(L, 2), to_name_ext(L, 1));
@ -285,7 +272,9 @@ static int mk_rewrite_rule_set(lua_State * L) {
static int add_rewrite_rules(lua_State * L) { static int add_rewrite_rules(lua_State * L) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
if (nargs == 2) if (nargs == 1)
add_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1));
else if (nargs == 2)
add_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2)); add_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2));
else else
add_rewrite_rules(rw_shared_environment(L, 3), to_name_ext(L, 1), to_name_ext(L, 2)); add_rewrite_rules(rw_shared_environment(L, 3), to_name_ext(L, 1), to_name_ext(L, 2));
@ -294,7 +283,9 @@ static int add_rewrite_rules(lua_State * L) {
static int enable_rewrite_rules(lua_State * L) { static int enable_rewrite_rules(lua_State * L) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
if (nargs == 3) if (nargs == 2)
enable_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), lua_toboolean(L, 2));
else if (nargs == 3)
enable_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3)); enable_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3));
else else
enable_rewrite_rules(rw_shared_environment(L, 4), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3)); enable_rewrite_rules(rw_shared_environment(L, 4), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3));
@ -306,7 +297,9 @@ static int show_rewrite_rules(lua_State * L) {
formatter fmt = get_global_formatter(L); formatter fmt = get_global_formatter(L);
options opts = get_global_options(L); options opts = get_global_options(L);
format r; format r;
if (nargs == 1) if (nargs == 0)
r = get_rewrite_rule_set(ro_shared_environment(L)).pp(fmt, opts);
else if (nargs == 1)
r = get_rewrite_rule_set(ro_shared_environment(L), to_name_ext(L, 1)).pp(fmt, opts); r = get_rewrite_rule_set(ro_shared_environment(L), to_name_ext(L, 1)).pp(fmt, opts);
else else
r = get_rewrite_rule_set(ro_shared_environment(L, 2), to_name_ext(L, 1)).pp(fmt, opts); r = get_rewrite_rule_set(ro_shared_environment(L, 2), to_name_ext(L, 1)).pp(fmt, opts);

View file

@ -8,23 +8,35 @@ Author: Leonardo de Moura
#include <memory> #include <memory>
#include <functional> #include <functional>
#include "util/lua.h" #include "util/lua.h"
#include "util/list.h"
#include "util/splay_tree.h"
#include "util/name.h"
#include "kernel/environment.h" #include "kernel/environment.h"
#include "kernel/formatter.h" #include "kernel/formatter.h"
namespace lean { namespace lean {
/**
\brief Actual implementation of the \c rewrite_rule_set class.
*/
class rewrite_rule_set { class rewrite_rule_set {
class imp; struct rewrite_rule;
std::unique_ptr<imp> m_ptr; typedef splay_tree<name, name_quick_cmp> name_set;
rewrite_rule_set(imp * ptr); ro_environment::weak_ref m_env;
list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees
name_set m_disabled_rules;
bool enabled(rewrite_rule const & rule) const;
public: public:
rewrite_rule_set(ro_environment const & env); rewrite_rule_set(ro_environment const & env);
rewrite_rule_set(rewrite_rule_set const & rs); rewrite_rule_set(rewrite_rule_set const & other);
~rewrite_rule_set(); ~rewrite_rule_set();
/** /**
\brief Convert the theorem \c th with proof \c proof into conditional rewrite rules, and \brief Convert the theorem \c th with proof \c proof into conditional rewrite rules, and
insert the rules into this rule set. The new rules are tagged with the given \c id. insert the rules into this rule set. The new rules are tagged with the given \c id.
*/ */
void insert(name const & id, expr const & th, expr const & proof); void insert(name const & id, expr const & th, expr const & proof);
/** /**
\brief Convert the theorem/axiom named \c th_name in the environment into conditional rewrite rules, \brief Convert the theorem/axiom named \c th_name in the environment into conditional rewrite rules,
and insert the rules into this rule set. The new rules are tagged with the theorem name. and insert the rules into this rule set. The new rules are tagged with the theorem name.
@ -33,11 +45,12 @@ public:
*/ */
void insert(name const & th_name); void insert(name const & th_name);
/** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */
void enable(name const & id, bool flag);
/** \brief Return true iff the conditional rewrite rules tagged with the given id are enabled. */ /** \brief Return true iff the conditional rewrite rules tagged with the given id are enabled. */
bool enabled(name const & id) const; bool enabled(name const & id) const;
/** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */
void enable(name const & id, bool f);
typedef std::function<bool(expr const &, expr const &, bool is_permutation, expr const &)> match_fn; // NOLINT typedef std::function<bool(expr const &, expr const &, bool is_permutation, expr const &)> match_fn; // NOLINT
typedef std::function<void(name const &, expr const &, expr const &, bool)> visit_fn; typedef std::function<void(name const &, expr const &, expr const &, bool)> visit_fn;
@ -51,7 +64,8 @@ public:
The argument \c proof is the proof for \c ceq. The argument \c proof is the proof for \c ceq.
*/ */
void for_each_match_candidate(expr const & e, match_fn const & fn) const; void for_each_match_candidate(expr const &, match_fn const & fn) const;
/** \brief Execute <tt>fn(id, ceq, proof, enabled)</tt> for each rule in this rule set. */ /** \brief Execute <tt>fn(id, ceq, proof, enabled)</tt> for each rule in this rule set. */
void for_each(visit_fn const & fn) const; void for_each(visit_fn const & fn) const;
@ -59,13 +73,14 @@ public:
format pp(formatter const & fmt, options const & opts) const; format pp(formatter const & fmt, options const & opts) const;
}; };
name const & get_default_rewrite_rule_set_id();
/** /**
\brief Create a rewrite rule set inside the given environment. \brief Create a rewrite rule set inside the given environment.
\remark The rule set is saved when the environment is serialized. \remark The rule set is saved when the environment is serialized.
\remark This procedure throws an exception if the environment already contains a rule set named \c rule_set_id. \remark This procedure throws an exception if the environment already contains a rule set named \c rule_set_id.
*/ */
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id); void mk_rewrite_rule_set(environment const & env, name const & rule_set_id = get_default_rewrite_rule_set_id());
/** /**
\brief Convert the theorem named \c th_name into conditional rewrite rules \brief Convert the theorem named \c th_name into conditional rewrite rules
and insert them in the rule set named \c rule_set_id in the given environment. and insert them in the rule set named \c rule_set_id in the given environment.
@ -74,18 +89,25 @@ void mk_rewrite_rule_set(environment const & env, name const & rule_set_id);
\remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id. \remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id.
*/ */
void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name); void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name);
inline void add_rewrite_rules(environment const & env, name const & th_name) {
add_rewrite_rules(env, get_default_rewrite_rule_set_id(), th_name);
}
/** /**
\brief Enable/disable the rewrite rules whose id is \c rule_id in the given rule set. \brief Enable/disable the rewrite rules whose id is \c rule_id in the given rule set.
\remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id. \remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id.
*/ */
void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag); void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag);
inline void enable_rewrite_rules(environment const & env, name const & rule_id, bool flag) {
enable_rewrite_rules(env, get_default_rewrite_rule_set_id(), rule_id, flag);
}
/** /**
\brief Return the rule set name \c rule_set_id in the given environment. \brief Return the rule set name \c rule_set_id in the given environment.
\remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id. \remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id.
*/ */
rewrite_rule_set const & get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id); rewrite_rule_set get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id = get_default_rewrite_rule_set_id());
void open_rewrite_rule_set(lua_State * L); void open_rewrite_rule_set(lua_State * L);
} }

View file

@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura Author: Leonardo de Moura
*/ */
#include <utility> #include <utility>
#include <vector>
#include "util/flet.h" #include "util/flet.h"
#include "util/interrupt.h" #include "util/interrupt.h"
#include "kernel/type_checker.h" #include "kernel/type_checker.h"
@ -15,6 +16,7 @@ Author: Leonardo de Moura
#include "library/kernel_bindings.h" #include "library/kernel_bindings.h"
#include "library/expr_pair.h" #include "library/expr_pair.h"
#include "library/hop_match.h" #include "library/hop_match.h"
#include "library/simplifier/rewrite_rule_set.h"
#ifndef LEAN_SIMPLIFIER_PROOFS #ifndef LEAN_SIMPLIFIER_PROOFS
#define LEAN_SIMPLIFIER_PROOFS true #define LEAN_SIMPLIFIER_PROOFS true
@ -75,10 +77,12 @@ unsigned get_simplifier_max_steps(options const & opts) {
} }
class simplifier_fn { class simplifier_fn {
typedef std::vector<rewrite_rule_set> rule_sets;
ro_environment m_env; ro_environment m_env;
type_checker m_tc; type_checker m_tc;
bool m_has_heq; bool m_has_heq;
context m_ctx; context m_ctx;
rule_sets m_rule_sets;
// Configuration // Configuration
bool m_proofs_enabled; bool m_proofs_enabled;
@ -385,8 +389,8 @@ class simplifier_fn {
} }
public: public:
simplifier_fn(ro_environment const & env, options const & o): simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs):
m_env(env), m_tc(env) { m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs) {
m_has_heq = m_env->imported("heq"); m_has_heq = m_env->imported("heq");
set_options(o); set_options(o);
} }
@ -402,19 +406,42 @@ public:
} }
}; };
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts) { expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
return simplifier_fn(env, opts)(e, ctx); unsigned num_rs, rewrite_rule_set const * rs) {
return simplifier_fn(env, opts, num_rs, rs)(e, ctx);
} }
static int simplify_core(lua_State * L, expr const & e, ro_shared_environment const & env) { expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
unsigned num_ns, name const * ns) {
buffer<rewrite_rule_set> rules;
for (unsigned i = 0; i < num_ns; i++)
rules.push_back(get_rewrite_rule_set(env, ns[i]));
return simplify(e, env, ctx, opts, num_ns, rules.data());
}
static int simplify_core(lua_State * L, ro_shared_environment const & env) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
expr const & e = to_expr(L, 1);
buffer<rewrite_rule_set> rules;
if (nargs == 1) {
rules.push_back(get_rewrite_rule_set(env));
} else {
luaL_checktype(L, 2, LUA_TTABLE);
name r;
int n = objlen(L, 2);
for (int i = 1; i <= n; i++) {
lua_rawgeti(L, 2, i);
rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1)));
lua_pop(L, 1);
}
}
context ctx; context ctx;
options opts; options opts;
if (nargs >= 3)
ctx = to_context(L, 3);
if (nargs >= 4) if (nargs >= 4)
opts = to_options(L, 4); ctx = to_context(L, 4);
auto r = simplify(e, env, ctx, opts); if (nargs >= 5)
opts = to_options(L, 5);
auto r = simplify(e, env, ctx, opts, rules.size(), rules.data());
push_expr(L, r.first); push_expr(L, r.first);
push_expr(L, r.second); push_expr(L, r.second);
return 2; return 2;
@ -422,11 +449,10 @@ static int simplify_core(lua_State * L, expr const & e, ro_shared_environment co
static int simplify(lua_State * L) { static int simplify(lua_State * L) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
expr const & e = to_expr(L, 1); if (nargs <= 2)
if (nargs == 1) return simplify_core(L, ro_shared_environment(L));
return simplify_core(L, e, ro_shared_environment(L));
else else
return simplify_core(L, e, ro_shared_environment(L, 2)); return simplify_core(L, ro_shared_environment(L, 3));
} }
void open_simplifier(lua_State * L) { void open_simplifier(lua_State * L) {

View file

@ -10,6 +10,9 @@ Author: Leonardo de Moura
#include "library/expr_pair.h" #include "library/expr_pair.h"
namespace lean { namespace lean {
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts); expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & pts,
unsigned num_rs, rewrite_rule_set const * rs);
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
unsigned num_ns, name const * ns);
void open_simplifier(lua_State * L); void open_simplifier(lua_State * L);
} }