refactor(library/simplifier): rewriter_rule_set

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-18 17:23:41 -08:00
parent feea96e84d
commit 466285c577
4 changed files with 179 additions and 135 deletions

View file

@ -15,114 +15,94 @@ Author: Leonardo de Moura
#include "library/simplifier/rewrite_rule_set.h"
namespace lean {
/**
\brief Actual implementation of the \c rewrite_rule_set class.
*/
class rewrite_rule_set::imp {
typedef splay_tree<name, name_quick_cmp> name_set;
struct rewrite_rule {
name m_id;
expr m_lhs;
expr m_ceq;
expr m_proof;
bool m_is_permutation;
rewrite_rule(name const & id, expr const & lhs, expr const & ceq, expr const & proof, bool is_permutation):
m_id(id), m_lhs(lhs), m_ceq(ceq), m_proof(proof), m_is_permutation(is_permutation) {}
};
ro_environment::weak_ref m_env;
list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees
name_set m_disabled_rules;
public:
imp(ro_environment const & env):m_env(env.to_weak_ref()) {}
void insert(name const & id, expr const & th, expr const & proof) {
ro_environment env(m_env);
for (auto const & p : to_ceqs(env, th, proof)) {
expr const & ceq = p.first;
expr const & proof = p.second;
bool is_perm = is_permutation_ceq(ceq);
expr lhs = ceq;
while (is_pi(lhs)) {
lhs = abst_body(lhs);
}
lean_assert(is_equality(lhs));
lhs = arg(lhs, num_args(lhs) - 2);
m_rule_set.emplace_front(id, lhs, ceq, proof, is_perm);
}
}
void insert(name const & th_name) {
ro_environment env(m_env);
auto obj = env->find_object(th_name);
if (obj && (obj->is_theorem() || obj->is_axiom())) {
insert(th_name, obj->get_type(), mk_constant(th_name));
} else {
throw exception(sstream() << "'" << th_name << "' is not a theorem nor an axiom");
}
}
bool enabled(rewrite_rule const & rule) const {
return !m_disabled_rules.contains(rule.m_id);
}
bool enabled(name const & id) const {
return !m_disabled_rules.contains(id);
}
void enable(name const & id, bool f) {
if (f)
m_disabled_rules.erase(id);
else
m_disabled_rules.insert(id);
}
void for_each_match_candidate(expr const &, match_fn const & fn) const {
auto l = m_rule_set;
for (auto const & rule : l) {
if (enabled(rule) && fn(rule.m_lhs, rule.m_ceq, rule.m_is_permutation, rule.m_proof))
return;
}
}
void for_each(visit_fn const & fn) const {
auto l = m_rule_set;
for (auto const & rule : l) {
fn(rule.m_id, rule.m_ceq, rule.m_proof, enabled(rule));
}
}
format pp(formatter const & fmt, options const & opts) const {
format r;
bool first = true;
unsigned indent = get_pp_indent(opts);
for_each([&](name const & name, expr const & ceq, expr const &, bool enabled) {
if (first)
first = false;
else
r += line();
r += format(name);
if (!enabled)
r += format(" [disabled]");
r += format{space(), colon(), space()};
r += nest(indent, fmt(ceq, opts));
});
return r;
}
struct rewrite_rule_set::rewrite_rule {
name m_id;
expr m_lhs;
expr m_ceq;
expr m_proof;
bool m_is_permutation;
rewrite_rule(name const & id, expr const & lhs, expr const & ceq, expr const & proof, bool is_permutation):
m_id(id), m_lhs(lhs), m_ceq(ceq), m_proof(proof), m_is_permutation(is_permutation) {}
};
rewrite_rule_set::rewrite_rule_set(imp * ptr):m_ptr(ptr) {}
rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_ptr(new imp(env)) {}
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & rs):m_ptr(new imp(*rs.m_ptr)) {}
rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_env(env.to_weak_ref()) {}
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other):
m_env(other.m_env), m_rule_set(other.m_rule_set), m_disabled_rules(other.m_disabled_rules) {}
rewrite_rule_set::~rewrite_rule_set() {}
void rewrite_rule_set::insert(name const & id, expr const & th, expr const & proof) { m_ptr->insert(id, th, proof); }
void rewrite_rule_set::insert(name const & th_name) { m_ptr->insert(th_name); }
bool rewrite_rule_set::enabled(name const & id) const { return m_ptr->enabled(id); }
void rewrite_rule_set::enable(name const & id, bool f) { m_ptr->enable(id, f); }
void rewrite_rule_set::for_each_match_candidate(expr const & e, match_fn const & fn) const {
m_ptr->for_each_match_candidate(e, fn);
void rewrite_rule_set::insert(name const & id, expr const & th, expr const & proof) {
ro_environment env(m_env);
for (auto const & p : to_ceqs(env, th, proof)) {
expr const & ceq = p.first;
expr const & proof = p.second;
bool is_perm = is_permutation_ceq(ceq);
expr lhs = ceq;
while (is_pi(lhs)) {
lhs = abst_body(lhs);
}
lean_assert(is_equality(lhs));
lhs = arg(lhs, num_args(lhs) - 2);
m_rule_set.emplace_front(id, lhs, ceq, proof, is_perm);
}
}
void rewrite_rule_set::insert(name const & th_name) {
ro_environment env(m_env);
auto obj = env->find_object(th_name);
if (obj && (obj->is_theorem() || obj->is_axiom())) {
insert(th_name, obj->get_type(), mk_constant(th_name));
} else {
throw exception(sstream() << "'" << th_name << "' is not a theorem nor an axiom");
}
}
bool rewrite_rule_set::enabled(rewrite_rule const & rule) const {
return !m_disabled_rules.contains(rule.m_id);
}
bool rewrite_rule_set::enabled(name const & id) const {
return !m_disabled_rules.contains(id);
}
void rewrite_rule_set::enable(name const & id, bool f) {
if (f)
m_disabled_rules.erase(id);
else
m_disabled_rules.insert(id);
}
void rewrite_rule_set::for_each_match_candidate(expr const &, match_fn const & fn) const {
auto l = m_rule_set;
for (auto const & rule : l) {
if (enabled(rule) && fn(rule.m_lhs, rule.m_ceq, rule.m_is_permutation, rule.m_proof))
return;
}
}
void rewrite_rule_set::for_each(visit_fn const & fn) const {
auto l = m_rule_set;
for (auto const & rule : l) {
fn(rule.m_id, rule.m_ceq, rule.m_proof, enabled(rule));
}
}
format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const {
format r;
bool first = true;
unsigned indent = get_pp_indent(opts);
for_each([&](name const & name, expr const & ceq, expr const &, bool enabled) {
if (first)
first = false;
else
r += line();
r += format(name);
if (!enabled)
r += format(" [disabled]");
r += format{space(), colon(), space()};
r += nest(indent, fmt(ceq, opts));
});
return r;
}
void rewrite_rule_set::for_each(visit_fn const & fn) const { m_ptr->for_each(fn); }
format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const { return m_ptr->pp(fmt, opts); }
class mk_rewrite_rule_set_obj : public neutral_object_cell {
name m_rule_set_id;
@ -234,7 +214,7 @@ struct rewrite_rule_set_extension : public environment_extension {
env->add_neutral_object(new enable_rewrite_rules_obj(rule_set_id, rule_id, flag));
}
rewrite_rule_set const & get_rewrite_rule_set(name const & rule_set_id) const {
rewrite_rule_set get_rewrite_rule_set(name const & rule_set_id) const {
return find_ro(rule_set_id);
}
};
@ -258,6 +238,11 @@ static rewrite_rule_set_extension & to_ext(environment const & env) {
return env->get_extension<rewrite_rule_set_extension>(g_rewrite_rule_set_extension_initializer.m_extid);
}
static name g_default_rewrite_rule_set_id("default");
name const & get_default_rewrite_rule_set_id() {
return g_default_rewrite_rule_set_id;
}
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id) {
to_ext(env).mk_rewrite_rule_set(env, rule_set_id);
}
@ -270,13 +255,15 @@ void enable_rewrite_rules(environment const & env, name const & rule_set_id, nam
to_ext(env).enable_rewrite_rules(env, rule_set_id, rule_id, flag);
}
rewrite_rule_set const & get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) {
rewrite_rule_set get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) {
return to_ext(env).get_rewrite_rule_set(rule_set_id);
}
static int mk_rewrite_rule_set(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs == 1)
if (nargs == 0)
mk_rewrite_rule_set(rw_shared_environment(L));
else if (nargs == 1)
mk_rewrite_rule_set(rw_shared_environment(L), to_name_ext(L, 1));
else
mk_rewrite_rule_set(rw_shared_environment(L, 2), to_name_ext(L, 1));
@ -285,7 +272,9 @@ static int mk_rewrite_rule_set(lua_State * L) {
static int add_rewrite_rules(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs == 2)
if (nargs == 1)
add_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1));
else if (nargs == 2)
add_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2));
else
add_rewrite_rules(rw_shared_environment(L, 3), to_name_ext(L, 1), to_name_ext(L, 2));
@ -294,7 +283,9 @@ static int add_rewrite_rules(lua_State * L) {
static int enable_rewrite_rules(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs == 3)
if (nargs == 2)
enable_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), lua_toboolean(L, 2));
else if (nargs == 3)
enable_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3));
else
enable_rewrite_rules(rw_shared_environment(L, 4), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3));
@ -306,7 +297,9 @@ static int show_rewrite_rules(lua_State * L) {
formatter fmt = get_global_formatter(L);
options opts = get_global_options(L);
format r;
if (nargs == 1)
if (nargs == 0)
r = get_rewrite_rule_set(ro_shared_environment(L)).pp(fmt, opts);
else if (nargs == 1)
r = get_rewrite_rule_set(ro_shared_environment(L), to_name_ext(L, 1)).pp(fmt, opts);
else
r = get_rewrite_rule_set(ro_shared_environment(L, 2), to_name_ext(L, 1)).pp(fmt, opts);

View file

@ -8,23 +8,35 @@ Author: Leonardo de Moura
#include <memory>
#include <functional>
#include "util/lua.h"
#include "util/list.h"
#include "util/splay_tree.h"
#include "util/name.h"
#include "kernel/environment.h"
#include "kernel/formatter.h"
namespace lean {
/**
\brief Actual implementation of the \c rewrite_rule_set class.
*/
class rewrite_rule_set {
class imp;
std::unique_ptr<imp> m_ptr;
rewrite_rule_set(imp * ptr);
struct rewrite_rule;
typedef splay_tree<name, name_quick_cmp> name_set;
ro_environment::weak_ref m_env;
list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees
name_set m_disabled_rules;
bool enabled(rewrite_rule const & rule) const;
public:
rewrite_rule_set(ro_environment const & env);
rewrite_rule_set(rewrite_rule_set const & rs);
rewrite_rule_set(rewrite_rule_set const & other);
~rewrite_rule_set();
/**
\brief Convert the theorem \c th with proof \c proof into conditional rewrite rules, and
insert the rules into this rule set. The new rules are tagged with the given \c id.
*/
void insert(name const & id, expr const & th, expr const & proof);
/**
\brief Convert the theorem/axiom named \c th_name in the environment into conditional rewrite rules,
and insert the rules into this rule set. The new rules are tagged with the theorem name.
@ -33,11 +45,12 @@ public:
*/
void insert(name const & th_name);
/** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */
void enable(name const & id, bool flag);
/** \brief Return true iff the conditional rewrite rules tagged with the given id are enabled. */
bool enabled(name const & id) const;
/** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */
void enable(name const & id, bool f);
typedef std::function<bool(expr const &, expr const &, bool is_permutation, expr const &)> match_fn; // NOLINT
typedef std::function<void(name const &, expr const &, expr const &, bool)> visit_fn;
@ -51,7 +64,8 @@ public:
The argument \c proof is the proof for \c ceq.
*/
void for_each_match_candidate(expr const & e, match_fn const & fn) const;
void for_each_match_candidate(expr const &, match_fn const & fn) const;
/** \brief Execute <tt>fn(id, ceq, proof, enabled)</tt> for each rule in this rule set. */
void for_each(visit_fn const & fn) const;
@ -59,13 +73,14 @@ public:
format pp(formatter const & fmt, options const & opts) const;
};
name const & get_default_rewrite_rule_set_id();
/**
\brief Create a rewrite rule set inside the given environment.
\remark The rule set is saved when the environment is serialized.
\remark This procedure throws an exception if the environment already contains a rule set named \c rule_set_id.
*/
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id);
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id = get_default_rewrite_rule_set_id());
/**
\brief Convert the theorem named \c th_name into conditional rewrite rules
and insert them in the rule set named \c rule_set_id in the given environment.
@ -74,18 +89,25 @@ void mk_rewrite_rule_set(environment const & env, name const & rule_set_id);
\remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id.
*/
void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name);
inline void add_rewrite_rules(environment const & env, name const & th_name) {
add_rewrite_rules(env, get_default_rewrite_rule_set_id(), th_name);
}
/**
\brief Enable/disable the rewrite rules whose id is \c rule_id in the given rule set.
\remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id.
*/
void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag);
inline void enable_rewrite_rules(environment const & env, name const & rule_id, bool flag) {
enable_rewrite_rules(env, get_default_rewrite_rule_set_id(), rule_id, flag);
}
/**
\brief Return the rule set name \c rule_set_id in the given environment.
\remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id.
*/
rewrite_rule_set const & get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id);
rewrite_rule_set get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id = get_default_rewrite_rule_set_id());
void open_rewrite_rule_set(lua_State * L);
}

View file

@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <utility>
#include <vector>
#include "util/flet.h"
#include "util/interrupt.h"
#include "kernel/type_checker.h"
@ -15,6 +16,7 @@ Author: Leonardo de Moura
#include "library/kernel_bindings.h"
#include "library/expr_pair.h"
#include "library/hop_match.h"
#include "library/simplifier/rewrite_rule_set.h"
#ifndef LEAN_SIMPLIFIER_PROOFS
#define LEAN_SIMPLIFIER_PROOFS true
@ -75,10 +77,12 @@ unsigned get_simplifier_max_steps(options const & opts) {
}
class simplifier_fn {
typedef std::vector<rewrite_rule_set> rule_sets;
ro_environment m_env;
type_checker m_tc;
bool m_has_heq;
context m_ctx;
rule_sets m_rule_sets;
// Configuration
bool m_proofs_enabled;
@ -385,8 +389,8 @@ class simplifier_fn {
}
public:
simplifier_fn(ro_environment const & env, options const & o):
m_env(env), m_tc(env) {
simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs):
m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs) {
m_has_heq = m_env->imported("heq");
set_options(o);
}
@ -402,19 +406,42 @@ public:
}
};
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts) {
return simplifier_fn(env, opts)(e, ctx);
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
unsigned num_rs, rewrite_rule_set const * rs) {
return simplifier_fn(env, opts, num_rs, rs)(e, ctx);
}
static int simplify_core(lua_State * L, expr const & e, ro_shared_environment const & env) {
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
unsigned num_ns, name const * ns) {
buffer<rewrite_rule_set> rules;
for (unsigned i = 0; i < num_ns; i++)
rules.push_back(get_rewrite_rule_set(env, ns[i]));
return simplify(e, env, ctx, opts, num_ns, rules.data());
}
static int simplify_core(lua_State * L, ro_shared_environment const & env) {
int nargs = lua_gettop(L);
expr const & e = to_expr(L, 1);
buffer<rewrite_rule_set> rules;
if (nargs == 1) {
rules.push_back(get_rewrite_rule_set(env));
} else {
luaL_checktype(L, 2, LUA_TTABLE);
name r;
int n = objlen(L, 2);
for (int i = 1; i <= n; i++) {
lua_rawgeti(L, 2, i);
rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1)));
lua_pop(L, 1);
}
}
context ctx;
options opts;
if (nargs >= 3)
ctx = to_context(L, 3);
if (nargs >= 4)
opts = to_options(L, 4);
auto r = simplify(e, env, ctx, opts);
ctx = to_context(L, 4);
if (nargs >= 5)
opts = to_options(L, 5);
auto r = simplify(e, env, ctx, opts, rules.size(), rules.data());
push_expr(L, r.first);
push_expr(L, r.second);
return 2;
@ -422,11 +449,10 @@ static int simplify_core(lua_State * L, expr const & e, ro_shared_environment co
static int simplify(lua_State * L) {
int nargs = lua_gettop(L);
expr const & e = to_expr(L, 1);
if (nargs == 1)
return simplify_core(L, e, ro_shared_environment(L));
if (nargs <= 2)
return simplify_core(L, ro_shared_environment(L));
else
return simplify_core(L, e, ro_shared_environment(L, 2));
return simplify_core(L, ro_shared_environment(L, 3));
}
void open_simplifier(lua_State * L) {

View file

@ -10,6 +10,9 @@ Author: Leonardo de Moura
#include "library/expr_pair.h"
namespace lean {
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts);
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & pts,
unsigned num_rs, rewrite_rule_set const * rs);
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
unsigned num_ns, name const * ns);
void open_simplifier(lua_State * L);
}