2014-01-18 20:50:46 +00:00
|
|
|
/*
|
|
|
|
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
|
|
|
|
Released under Apache 2.0 license as described in the file LICENSE.
|
|
|
|
|
|
|
|
Author: Leonardo de Moura
|
|
|
|
*/
|
|
|
|
#include "util/splay_tree.h"
|
|
|
|
#include "util/list_fn.h"
|
|
|
|
#include "util/sstream.h"
|
|
|
|
#include "kernel/environment.h"
|
2014-01-18 23:37:36 +00:00
|
|
|
#include "library/io_state_stream.h"
|
2014-01-18 20:50:46 +00:00
|
|
|
#include "library/equality.h"
|
2014-01-18 23:37:36 +00:00
|
|
|
#include "library/kernel_bindings.h"
|
2014-01-18 20:50:46 +00:00
|
|
|
#include "library/simplifier/ceq.h"
|
|
|
|
#include "library/simplifier/rewrite_rule_set.h"
|
|
|
|
|
|
|
|
namespace lean {
|
2014-01-19 04:52:33 +00:00
|
|
|
/**
   \brief Build a rewrite rule.

   Each argument is stored, unchanged, in the data member with the matching name.
*/
rewrite_rule::rewrite_rule(name const & id, expr const & lhs, expr const & rhs, expr const & ceq, expr const & proof,
                           unsigned num_args, bool is_permutation):
    m_id(id),
    m_lhs(lhs),
    m_rhs(rhs),
    m_ceq(ceq),
    m_proof(proof),
    m_num_args(num_args),
    m_is_permutation(is_permutation)
{}
|
2014-01-18 20:50:46 +00:00
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
/**
   \brief Create a rule set attached to \c env.
   Only a weak reference to the environment is kept; the remaining members are default-initialized.
*/
rewrite_rule_set::rewrite_rule_set(ro_environment const & env):
    m_env(env.to_weak_ref()) {
}
|
|
|
|
/**
   \brief Copy constructor: copies the environment reference, the rules, and the set of disabled rule ids.
*/
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other):
    m_env(other.m_env),
    m_rule_set(other.m_rule_set),
    m_disabled_rules(other.m_disabled_rules) {
}
|
|
|
|
/** \brief Destructor; no work beyond destroying the members. */
rewrite_rule_set::~rewrite_rule_set() = default;
|
2014-01-18 20:50:46 +00:00
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
/**
   \brief Add to this set the rewrite rules obtained from \c th and its proof \c proof.

   \c to_ceqs yields (ceq, proof) pairs; one rule is created per pair, and every
   rule created by a single call shares the identifier \c id.
*/
void rewrite_rule_set::insert(name const & id, expr const & th, expr const & proof) {
    ro_environment env(m_env);
    for (auto const & ceq_and_proof : to_ceqs(env, th, proof)) {
        expr const & ceq       = ceq_and_proof.first;
        expr const & ceq_proof = ceq_and_proof.second;
        bool const is_perm     = is_permutation_ceq(ceq);
        // Peel off the leading Pi binders (counting them) to reach the body.
        expr body        = ceq;
        unsigned num_pis = 0;
        for (; is_pi(body); body = abst_body(body))
            ++num_pis;
        lean_assert(is_equality(body));
        // The last two arguments of the body become the rule's lhs and rhs.
        auto const nargs = num_args(body);
        m_rule_set = cons(rewrite_rule(id, arg(body, nargs - 2), arg(body, nargs - 1),
                                       ceq, ceq_proof, num_pis, is_perm),
                          m_rule_set);
    }
}
|
2014-01-18 20:50:46 +00:00
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
/**
   \brief Add the rules obtained from the theorem or axiom named \c th_name.

   The object is looked up in the environment; its type is used as the statement
   and the constant \c th_name as the proof.

   \remark Throws an exception when \c th_name is not found, or is neither a theorem nor an axiom.
*/
void rewrite_rule_set::insert(name const & th_name) {
    ro_environment env(m_env);
    auto obj = env->find_object(th_name);
    if (!obj || !(obj->is_theorem() || obj->is_axiom()))
        throw exception(sstream() << "'" << th_name << "' is not a theorem nor an axiom");
    insert(th_name, obj->get_type(), mk_constant(th_name));
}
|
2014-01-18 20:50:46 +00:00
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
/** \brief Return true iff the identifier of \c rule is not in the set of disabled rules. */
bool rewrite_rule_set::enabled(rewrite_rule const & rule) const {
    return enabled(rule.get_id());
}
|
2014-01-18 20:50:46 +00:00
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
/** \brief Return true iff \c id is not in the set of disabled rules. */
bool rewrite_rule_set::enabled(name const & id) const {
    bool const is_disabled = m_disabled_rules.contains(id);
    return !is_disabled;
}
|
2014-01-18 20:50:46 +00:00
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
/**
   \brief Enable (\c f is true) or disable (\c f is false) the rules identified by \c id.

   Only the set of disabled ids is updated; the stored rules themselves are not touched.
*/
void rewrite_rule_set::enable(name const & id, bool f) {
    if (!f) {
        m_disabled_rules.insert(id);
    } else {
        m_disabled_rules.erase(id);
    }
}
|
2014-01-18 20:50:46 +00:00
|
|
|
|
2014-01-19 04:52:33 +00:00
|
|
|
/**
   \brief Apply \c fn to the enabled rules, in storage order, until it returns true.

   \return true if \c fn accepted some enabled rule, false otherwise.
   \remark The expression argument is currently unused: every enabled rule is offered to \c fn.
*/
bool rewrite_rule_set::find_match(expr const &, match_fn const & fn) const {
    auto rules = m_rule_set;  // iterate over a local copy of the list handle
    for (auto const & candidate : rules) {
        if (!enabled(candidate))
            continue;
        if (fn(candidate))
            return true;
    }
    return false;
}
|
2014-01-18 20:50:46 +00:00
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
/**
   \brief Invoke \c fn on every rule in the set (enabled or not).
   The second argument passed to \c fn says whether the rule is currently enabled.
*/
void rewrite_rule_set::for_each(visit_fn const & fn) const {
    auto rules = m_rule_set;  // iterate over a local copy of the list handle
    for (auto const & current : rules) {
        bool const is_enabled = enabled(current);
        fn(current, is_enabled);
    }
}
|
2014-01-19 01:23:41 +00:00
|
|
|
|
|
|
|
/**
   \brief Pretty print the rule set.

   Rules are separated by line breaks. Each entry has the form
   <tt>id [disabled] : ceq</tt>, where the <tt>[disabled]</tt> marker only appears
   for disabled rules and the ceq is formatted with \c fmt and nested by the
   indentation taken from \c opts.
*/
format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const {
    format result;
    bool at_start         = true;
    unsigned const indent = get_pp_indent(opts);
    for_each([&](rewrite_rule const & rule, bool enabled) {
            // A separator goes before every entry except the first one.
            if (!at_start)
                result += line();
            at_start = false;
            result += format(rule.get_id());
            if (!enabled)
                result += format(" [disabled]");
            result += format{space(), colon(), space()};
            result += nest(indent, fmt(rule.get_ceq(), opts));
        });
    return result;
}
|
2014-01-18 23:37:36 +00:00
|
|
|
|
|
|
|
class mk_rewrite_rule_set_obj : public neutral_object_cell {
|
|
|
|
name m_rule_set_id;
|
|
|
|
public:
|
|
|
|
mk_rewrite_rule_set_obj(name const & id):m_rule_set_id(id) {}
|
|
|
|
virtual ~mk_rewrite_rule_set_obj() {}
|
|
|
|
virtual char const * keyword() const { return "mk_rewrite_rule_set"; }
|
|
|
|
virtual void write(serializer & s) const { s << "mk_rrs" << m_rule_set_id; }
|
|
|
|
};
|
|
|
|
/** \brief Deserializer for "mk_rrs" entries: read the rule set name and create the rule set in \c env. */
static void read_rrs(environment const & env, io_state const &, deserializer & d) {
    name const rule_set_id = read_name(d);
    mk_rewrite_rule_set(env, rule_set_id);
}
|
|
|
|
// Associates the serialization tag "mk_rrs" with read_rrs.
static object_cell::register_deserializer_fn mk_rrs_ds("mk_rrs", read_rrs);
|
|
|
|
|
|
|
|
class add_rewrite_rules_obj : public neutral_object_cell {
|
|
|
|
name m_rule_set_id;
|
|
|
|
name m_th_name;
|
|
|
|
public:
|
|
|
|
add_rewrite_rules_obj(name const & rsid, name const & th_name):m_rule_set_id(rsid), m_th_name(th_name) {}
|
|
|
|
virtual ~add_rewrite_rules_obj() {}
|
|
|
|
virtual char const * keyword() const { return "add_rewrite_rules"; }
|
|
|
|
virtual void write(serializer & s) const { s << "add_rr" << m_rule_set_id << m_th_name; }
|
|
|
|
};
|
|
|
|
/**
   \brief Deserializer for "add_rr" entries.
   Reads the rule set name, then the theorem name (the order they are written in), and replays the addition in \c env.
*/
static void read_arr(environment const & env, io_state const &, deserializer & d) {
    // The two reads must stay in this order, so they are kept as separate statements.
    name const rule_set_id  = read_name(d);
    name const theorem_name = read_name(d);
    add_rewrite_rules(env, rule_set_id, theorem_name);
}
|
|
|
|
// Associates the serialization tag "add_rr" with read_arr.
static object_cell::register_deserializer_fn add_rr_ds("add_rr", read_arr);
|
|
|
|
|
|
|
|
class enable_rewrite_rules_obj : public neutral_object_cell {
|
|
|
|
name m_rule_set_id;
|
|
|
|
name m_rule_id;
|
|
|
|
bool m_flag;
|
|
|
|
public:
|
|
|
|
enable_rewrite_rules_obj(name const & rsid, name const & id, bool flag):m_rule_set_id(rsid), m_rule_id(id), m_flag(flag) {}
|
|
|
|
virtual ~enable_rewrite_rules_obj() {}
|
|
|
|
virtual char const * keyword() const { return "enable_rewrite_rules_obj"; }
|
|
|
|
virtual void write(serializer & s) const { s << "enable_rr" << m_rule_set_id << m_rule_id << m_flag; }
|
|
|
|
};
|
|
|
|
/**
   \brief Deserializer for "enable_rr" entries.
   Reads the rule set name, the rule id and the flag (the order they are written in), and replays the change in \c env.
*/
static void read_enable_rr(environment const & env, io_state const &, deserializer & d) {
    // The reads must stay in this order, so they are kept as separate statements.
    name const rule_set_id = read_name(d);
    name const rule_id     = read_name(d);
    bool const enable_flag = d.read_bool();
    enable_rewrite_rules(env, rule_set_id, rule_id, enable_flag);
}
|
|
|
|
// Associates the serialization tag "enable_rr" with read_enable_rr.
static object_cell::register_deserializer_fn enable_rr_ds("enable_rr", read_enable_rr);
|
|
|
|
|
|
|
|
/**
|
|
|
|
\brief Extension for managing rewrite rule sets.
|
|
|
|
*/
|
|
|
|
struct rewrite_rule_set_extension : public environment_extension {
|
|
|
|
name_map<rewrite_rule_set> m_rule_sets;
|
|
|
|
|
|
|
|
rewrite_rule_set_extension const * get_parent() const {
|
|
|
|
return environment_extension::get_parent<rewrite_rule_set_extension>();
|
|
|
|
}
|
|
|
|
|
|
|
|
rewrite_rule_set const * find_ro_core(name const & rule_set_id) const {
|
|
|
|
auto it = m_rule_sets.find(rule_set_id);
|
|
|
|
if (it != m_rule_sets.end()) {
|
|
|
|
return &(it->second);
|
|
|
|
}
|
|
|
|
auto p = get_parent();
|
|
|
|
if (p) {
|
|
|
|
return p->find_ro_core(rule_set_id);
|
|
|
|
} else {
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
rewrite_rule_set const & find_ro(name const & rule_set_id) const {
|
|
|
|
auto rs = find_ro_core(rule_set_id);
|
|
|
|
if (rs)
|
|
|
|
return *rs;
|
|
|
|
throw exception(sstream() << "environment does not contain a rewrite rule set named '" << rule_set_id << "'");
|
|
|
|
}
|
|
|
|
|
|
|
|
rewrite_rule_set & find_rw(name const & rule_set_id) {
|
|
|
|
auto it = m_rule_sets.find(rule_set_id);
|
|
|
|
if (it != m_rule_sets.end())
|
|
|
|
return it->second;
|
|
|
|
auto p = get_parent();
|
|
|
|
if (p) {
|
|
|
|
auto const & rs = p->find_ro(rule_set_id);
|
|
|
|
m_rule_sets.insert(mk_pair(rule_set_id, rewrite_rule_set(rs)));
|
|
|
|
return m_rule_sets.find(rule_set_id)->second;
|
|
|
|
}
|
|
|
|
throw exception(sstream() << "environment does not contain a rewrite rule set named '" << rule_set_id << "'");
|
|
|
|
}
|
|
|
|
|
|
|
|
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id) {
|
|
|
|
if (find_ro_core(rule_set_id))
|
|
|
|
throw exception(sstream() << "environment already contains a rewrite rule set named '" << rule_set_id << "'");
|
|
|
|
m_rule_sets.insert(mk_pair(rule_set_id, rewrite_rule_set(env)));
|
|
|
|
env->add_neutral_object(new mk_rewrite_rule_set_obj(rule_set_id));
|
|
|
|
}
|
|
|
|
|
|
|
|
void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name) {
|
|
|
|
auto & rs = find_rw(rule_set_id);
|
|
|
|
rs.insert(th_name);
|
|
|
|
env->add_neutral_object(new add_rewrite_rules_obj(rule_set_id, th_name));
|
|
|
|
}
|
|
|
|
|
|
|
|
void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag) {
|
|
|
|
auto & rs = find_rw(rule_set_id);
|
|
|
|
rs.enable(rule_id, flag);
|
|
|
|
env->add_neutral_object(new enable_rewrite_rules_obj(rule_set_id, rule_id, flag));
|
|
|
|
}
|
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
rewrite_rule_set get_rewrite_rule_set(name const & rule_set_id) const {
|
2014-01-18 23:37:36 +00:00
|
|
|
return find_ro(rule_set_id);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
struct rewrite_rule_set_extension_initializer {
|
|
|
|
unsigned m_extid;
|
|
|
|
rewrite_rule_set_extension_initializer() {
|
|
|
|
m_extid = environment_cell::register_extension([](){
|
|
|
|
return std::unique_ptr<environment_extension>(new rewrite_rule_set_extension());
|
|
|
|
});
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
// File-local instance: constructing it registers the extension; its m_extid is used by to_ext below.
static rewrite_rule_set_extension_initializer g_rewrite_rule_set_extension_initializer;
|
|
|
|
|
|
|
|
/** \brief Retrieve, for reading, the rewrite rule set extension of \c env. */
static rewrite_rule_set_extension const & to_ext(ro_environment const & env) {
    unsigned const ext_id = g_rewrite_rule_set_extension_initializer.m_extid;
    return env->get_extension<rewrite_rule_set_extension>(ext_id);
}
|
|
|
|
|
|
|
|
/** \brief Retrieve, for modification, the rewrite rule set extension of \c env. */
static rewrite_rule_set_extension & to_ext(environment const & env) {
    unsigned const ext_id = g_rewrite_rule_set_extension_initializer.m_extid;
    return env->get_extension<rewrite_rule_set_extension>(ext_id);
}
|
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
// Name returned by get_default_rewrite_rule_set_id().
static name g_default_rewrite_rule_set_id("default");
|
|
|
|
/** \brief Return the name of the default rewrite rule set ("default"). */
name const & get_default_rewrite_rule_set_id() { return g_default_rewrite_rule_set_id; }
|
|
|
|
|
2014-01-18 23:37:36 +00:00
|
|
|
/** \brief Create the rewrite rule set \c rule_set_id in \c env (forwards to the environment's extension). */
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id) {
    rewrite_rule_set_extension & ext = to_ext(env);
    ext.mk_rewrite_rule_set(env, rule_set_id);
}
|
|
|
|
|
|
|
|
/** \brief Add the rules of theorem/axiom \c th_name to the rule set \c rule_set_id of \c env (forwards to the extension). */
void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name) {
    rewrite_rule_set_extension & ext = to_ext(env);
    ext.add_rewrite_rules(env, rule_set_id, th_name);
}
|
|
|
|
|
|
|
|
/** \brief Enable/disable the rules \c rule_id in the rule set \c rule_set_id of \c env (forwards to the extension). */
void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag) {
    rewrite_rule_set_extension & ext = to_ext(env);
    ext.enable_rewrite_rules(env, rule_set_id, rule_id, flag);
}
|
|
|
|
|
2014-01-19 01:23:41 +00:00
|
|
|
/** \brief Return a copy of the rule set \c rule_set_id stored in \c env (forwards to the extension). */
rewrite_rule_set get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) {
    rewrite_rule_set_extension const & ext = to_ext(env);
    return ext.get_rewrite_rule_set(rule_set_id);
}
|
|
|
|
|
|
|
|
/**
   \brief Lua binding for mk_rewrite_rule_set.

   - no arguments: called with the shared environment only;
   - one argument: the rule set name is taken from stack index 1;
   - otherwise: the name is taken from index 1 and the environment is obtained with index 2.

   Pushes no results.
*/
static int mk_rewrite_rule_set(lua_State * L) {
    switch (lua_gettop(L)) {
    case 0:
        mk_rewrite_rule_set(rw_shared_environment(L));
        break;
    case 1:
        mk_rewrite_rule_set(rw_shared_environment(L), to_name_ext(L, 1));
        break;
    default:
        mk_rewrite_rule_set(rw_shared_environment(L, 2), to_name_ext(L, 1));
        break;
    }
    return 0;
}
|
|
|
|
|
|
|
|
/**
   \brief Lua binding for add_rewrite_rules.

   - one argument: a single name from stack index 1, with the shared environment;
   - two arguments: names from indices 1 and 2, with the shared environment;
   - otherwise: names from indices 1 and 2, and the environment is obtained with index 3.

   Pushes no results.
*/
static int add_rewrite_rules(lua_State * L) {
    switch (lua_gettop(L)) {
    case 1:
        add_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1));
        break;
    case 2:
        add_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2));
        break;
    default:
        add_rewrite_rules(rw_shared_environment(L, 3), to_name_ext(L, 1), to_name_ext(L, 2));
        break;
    }
    return 0;
}
|
|
|
|
|
|
|
|
/**
   \brief Lua binding for enable_rewrite_rules.

   - two arguments: a name from stack index 1 and a boolean from index 2, with the shared environment;
   - three arguments: names from indices 1 and 2 and a boolean from index 3, with the shared environment;
   - otherwise: same as three arguments, but the environment is obtained with index 4.

   Pushes no results.
*/
static int enable_rewrite_rules(lua_State * L) {
    switch (lua_gettop(L)) {
    case 2:
        enable_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), lua_toboolean(L, 2));
        break;
    case 3:
        enable_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3));
        break;
    default:
        enable_rewrite_rules(rw_shared_environment(L, 4), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3));
        break;
    }
    return 0;
}
|
|
|
|
|
|
|
|
/**
   \brief Lua binding that pretty prints a rewrite rule set.

   - no arguments: called with the shared environment only;
   - one argument: the rule set name is taken from stack index 1;
   - otherwise: the name is taken from index 1 and the environment is obtained with index 2.

   The text goes to the regular channel of the io_state attached to \c L when there is one,
   and to std::cout otherwise. Pushes no results.
*/
static int show_rewrite_rules(lua_State * L) {
    int const nargs = lua_gettop(L);
    formatter fmt   = get_global_formatter(L);
    options opts    = get_global_options(L);
    format r;
    switch (nargs) {
    case 0:
        r = get_rewrite_rule_set(ro_shared_environment(L)).pp(fmt, opts);
        break;
    case 1:
        r = get_rewrite_rule_set(ro_shared_environment(L), to_name_ext(L, 1)).pp(fmt, opts);
        break;
    default:
        r = get_rewrite_rule_set(ro_shared_environment(L, 2), to_name_ext(L, 1)).pp(fmt, opts);
        break;
    }
    io_state * ios = get_io_state(L);
    if (ios) {
        regular(*ios) << mk_pair(r, opts) << endl;
    } else {
        std::cout << mk_pair(r, opts) << "\n";
    }
    return 0;
}
|
|
|
|
|
|
|
|
/** \brief Install the rewrite rule set bindings as global functions in the Lua state \c L. */
void open_rewrite_rule_set(lua_State * L) {
    // Each binding is exposed under the same name as its C++ implementation.
    SET_GLOBAL_FUN(mk_rewrite_rule_set,  "mk_rewrite_rule_set");
    SET_GLOBAL_FUN(add_rewrite_rules,    "add_rewrite_rules");
    SET_GLOBAL_FUN(enable_rewrite_rules, "enable_rewrite_rules");
    SET_GLOBAL_FUN(show_rewrite_rules,   "show_rewrite_rules");
}
|
2014-01-18 20:50:46 +00:00
|
|
|
}
|