feat(library/simplifier): add rewrite_rule_set extension for managing rewrite rules in an environment
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
eae79877ae
commit
feea96e84d
7 changed files with 284 additions and 22 deletions
|
@ -7,11 +7,13 @@ Author: Leonardo de Moura
|
|||
#pragma once
|
||||
#include "util/script_state.h"
|
||||
#include "library/simplifier/ceq.h"
|
||||
#include "library/simplifier/rewrite_rule_set.h"
|
||||
#include "library/simplifier/simplifier.h"
|
||||
|
||||
namespace lean {
|
||||
inline void open_simplifier_module(lua_State * L) {
|
||||
open_ceq(L);
|
||||
open_rewrite_rule_set(L);
|
||||
open_simplifier(L);
|
||||
}
|
||||
inline void register_simplifier_module() {
|
||||
|
|
|
@ -8,7 +8,9 @@ Author: Leonardo de Moura
|
|||
#include "util/list_fn.h"
|
||||
#include "util/sstream.h"
|
||||
#include "kernel/environment.h"
|
||||
#include "library/io_state_stream.h"
|
||||
#include "library/equality.h"
|
||||
#include "library/kernel_bindings.h"
|
||||
#include "library/simplifier/ceq.h"
|
||||
#include "library/simplifier/rewrite_rule_set.h"
|
||||
|
||||
|
@ -108,14 +110,218 @@ public:
|
|||
return r;
|
||||
}
|
||||
};
|
||||
rewrite_rule_set::rewrite_rule_set(imp * ptr):m_ptr(ptr) {}
|
||||
rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_ptr(new imp(env)) {}
|
||||
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & rs):m_ptr(new imp(*(rs.m_ptr))) {}
|
||||
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & rs):m_ptr(new imp(*rs.m_ptr)) {}
|
||||
rewrite_rule_set::~rewrite_rule_set() {}
|
||||
void rewrite_rule_set::insert(name const & id, expr const & th, expr const & proof) { m_ptr->insert(id, th, proof); }
|
||||
void rewrite_rule_set::insert(name const & th_name) { m_ptr->insert(th_name); }
|
||||
bool rewrite_rule_set::enabled(name const & id) const { return m_ptr->enabled(id); }
|
||||
void rewrite_rule_set::enable(name const & id, bool f) { m_ptr->enable(id, f); }
|
||||
void rewrite_rule_set::for_each_match_candidate(expr const & e, match_fn const & fn) { m_ptr->for_each_match_candidate(e, fn); }
|
||||
void rewrite_rule_set::for_each(visit_fn const & fn) { m_ptr->for_each(fn); }
|
||||
format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const { return m_ptr->pp(fmt, opts); }
|
||||
void rewrite_rule_set::for_each_match_candidate(expr const & e, match_fn const & fn) const {
|
||||
m_ptr->for_each_match_candidate(e, fn);
|
||||
}
|
||||
void rewrite_rule_set::for_each(visit_fn const & fn) const { m_ptr->for_each(fn); }
|
||||
format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const { return m_ptr->pp(fmt, opts); }
|
||||
|
||||
class mk_rewrite_rule_set_obj : public neutral_object_cell {
|
||||
name m_rule_set_id;
|
||||
public:
|
||||
mk_rewrite_rule_set_obj(name const & id):m_rule_set_id(id) {}
|
||||
virtual ~mk_rewrite_rule_set_obj() {}
|
||||
virtual char const * keyword() const { return "mk_rewrite_rule_set"; }
|
||||
virtual void write(serializer & s) const { s << "mk_rrs" << m_rule_set_id; }
|
||||
};
|
||||
static void read_rrs(environment const & env, io_state const &, deserializer & d) {
|
||||
name n = read_name(d);
|
||||
mk_rewrite_rule_set(env, n);
|
||||
}
|
||||
static object_cell::register_deserializer_fn mk_rrs_ds("mk_rrs", read_rrs);
|
||||
|
||||
class add_rewrite_rules_obj : public neutral_object_cell {
|
||||
name m_rule_set_id;
|
||||
name m_th_name;
|
||||
public:
|
||||
add_rewrite_rules_obj(name const & rsid, name const & th_name):m_rule_set_id(rsid), m_th_name(th_name) {}
|
||||
virtual ~add_rewrite_rules_obj() {}
|
||||
virtual char const * keyword() const { return "add_rewrite_rules"; }
|
||||
virtual void write(serializer & s) const { s << "add_rr" << m_rule_set_id << m_th_name; }
|
||||
};
|
||||
static void read_arr(environment const & env, io_state const &, deserializer & d) {
|
||||
name rsid = read_name(d);
|
||||
name th = read_name(d);
|
||||
add_rewrite_rules(env, rsid, th);
|
||||
}
|
||||
static object_cell::register_deserializer_fn add_rr_ds("add_rr", read_arr);
|
||||
|
||||
class enable_rewrite_rules_obj : public neutral_object_cell {
|
||||
name m_rule_set_id;
|
||||
name m_rule_id;
|
||||
bool m_flag;
|
||||
public:
|
||||
enable_rewrite_rules_obj(name const & rsid, name const & id, bool flag):m_rule_set_id(rsid), m_rule_id(id), m_flag(flag) {}
|
||||
virtual ~enable_rewrite_rules_obj() {}
|
||||
virtual char const * keyword() const { return "enable_rewrite_rules_obj"; }
|
||||
virtual void write(serializer & s) const { s << "enable_rr" << m_rule_set_id << m_rule_id << m_flag; }
|
||||
};
|
||||
static void read_enable_rr(environment const & env, io_state const &, deserializer & d) {
|
||||
name rsid = read_name(d);
|
||||
name id = read_name(d);
|
||||
bool flag = d.read_bool();
|
||||
enable_rewrite_rules(env, rsid, id, flag);
|
||||
}
|
||||
static object_cell::register_deserializer_fn enable_rr_ds("enable_rr", read_enable_rr);
|
||||
|
||||
/**
|
||||
\brief Extension for managing rewrite rule sets.
|
||||
*/
|
||||
struct rewrite_rule_set_extension : public environment_extension {
|
||||
name_map<rewrite_rule_set> m_rule_sets;
|
||||
|
||||
rewrite_rule_set_extension const * get_parent() const {
|
||||
return environment_extension::get_parent<rewrite_rule_set_extension>();
|
||||
}
|
||||
|
||||
rewrite_rule_set const * find_ro_core(name const & rule_set_id) const {
|
||||
auto it = m_rule_sets.find(rule_set_id);
|
||||
if (it != m_rule_sets.end()) {
|
||||
return &(it->second);
|
||||
}
|
||||
auto p = get_parent();
|
||||
if (p) {
|
||||
return p->find_ro_core(rule_set_id);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
rewrite_rule_set const & find_ro(name const & rule_set_id) const {
|
||||
auto rs = find_ro_core(rule_set_id);
|
||||
if (rs)
|
||||
return *rs;
|
||||
throw exception(sstream() << "environment does not contain a rewrite rule set named '" << rule_set_id << "'");
|
||||
}
|
||||
|
||||
rewrite_rule_set & find_rw(name const & rule_set_id) {
|
||||
auto it = m_rule_sets.find(rule_set_id);
|
||||
if (it != m_rule_sets.end())
|
||||
return it->second;
|
||||
auto p = get_parent();
|
||||
if (p) {
|
||||
auto const & rs = p->find_ro(rule_set_id);
|
||||
m_rule_sets.insert(mk_pair(rule_set_id, rewrite_rule_set(rs)));
|
||||
return m_rule_sets.find(rule_set_id)->second;
|
||||
}
|
||||
throw exception(sstream() << "environment does not contain a rewrite rule set named '" << rule_set_id << "'");
|
||||
}
|
||||
|
||||
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id) {
|
||||
if (find_ro_core(rule_set_id))
|
||||
throw exception(sstream() << "environment already contains a rewrite rule set named '" << rule_set_id << "'");
|
||||
m_rule_sets.insert(mk_pair(rule_set_id, rewrite_rule_set(env)));
|
||||
env->add_neutral_object(new mk_rewrite_rule_set_obj(rule_set_id));
|
||||
}
|
||||
|
||||
void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name) {
|
||||
auto & rs = find_rw(rule_set_id);
|
||||
rs.insert(th_name);
|
||||
env->add_neutral_object(new add_rewrite_rules_obj(rule_set_id, th_name));
|
||||
}
|
||||
|
||||
void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag) {
|
||||
auto & rs = find_rw(rule_set_id);
|
||||
rs.enable(rule_id, flag);
|
||||
env->add_neutral_object(new enable_rewrite_rules_obj(rule_set_id, rule_id, flag));
|
||||
}
|
||||
|
||||
rewrite_rule_set const & get_rewrite_rule_set(name const & rule_set_id) const {
|
||||
return find_ro(rule_set_id);
|
||||
}
|
||||
};
|
||||
|
||||
struct rewrite_rule_set_extension_initializer {
|
||||
unsigned m_extid;
|
||||
rewrite_rule_set_extension_initializer() {
|
||||
m_extid = environment_cell::register_extension([](){
|
||||
return std::unique_ptr<environment_extension>(new rewrite_rule_set_extension());
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
static rewrite_rule_set_extension_initializer g_rewrite_rule_set_extension_initializer;
|
||||
|
||||
static rewrite_rule_set_extension const & to_ext(ro_environment const & env) {
|
||||
return env->get_extension<rewrite_rule_set_extension>(g_rewrite_rule_set_extension_initializer.m_extid);
|
||||
}
|
||||
|
||||
static rewrite_rule_set_extension & to_ext(environment const & env) {
|
||||
return env->get_extension<rewrite_rule_set_extension>(g_rewrite_rule_set_extension_initializer.m_extid);
|
||||
}
|
||||
|
||||
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id) {
|
||||
to_ext(env).mk_rewrite_rule_set(env, rule_set_id);
|
||||
}
|
||||
|
||||
void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name) {
|
||||
to_ext(env).add_rewrite_rules(env, rule_set_id, th_name);
|
||||
}
|
||||
|
||||
void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag) {
|
||||
to_ext(env).enable_rewrite_rules(env, rule_set_id, rule_id, flag);
|
||||
}
|
||||
|
||||
rewrite_rule_set const & get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) {
|
||||
return to_ext(env).get_rewrite_rule_set(rule_set_id);
|
||||
}
|
||||
|
||||
static int mk_rewrite_rule_set(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
if (nargs == 1)
|
||||
mk_rewrite_rule_set(rw_shared_environment(L), to_name_ext(L, 1));
|
||||
else
|
||||
mk_rewrite_rule_set(rw_shared_environment(L, 2), to_name_ext(L, 1));
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int add_rewrite_rules(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
if (nargs == 2)
|
||||
add_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2));
|
||||
else
|
||||
add_rewrite_rules(rw_shared_environment(L, 3), to_name_ext(L, 1), to_name_ext(L, 2));
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int enable_rewrite_rules(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
if (nargs == 3)
|
||||
enable_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3));
|
||||
else
|
||||
enable_rewrite_rules(rw_shared_environment(L, 4), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3));
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int show_rewrite_rules(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
formatter fmt = get_global_formatter(L);
|
||||
options opts = get_global_options(L);
|
||||
format r;
|
||||
if (nargs == 1)
|
||||
r = get_rewrite_rule_set(ro_shared_environment(L), to_name_ext(L, 1)).pp(fmt, opts);
|
||||
else
|
||||
r = get_rewrite_rule_set(ro_shared_environment(L, 2), to_name_ext(L, 1)).pp(fmt, opts);
|
||||
io_state * ios = get_io_state(L);
|
||||
if (ios)
|
||||
regular(*ios) << mk_pair(r, opts) << endl;
|
||||
else
|
||||
std::cout << mk_pair(r, opts) << "\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
void open_rewrite_rule_set(lua_State * L) {
|
||||
SET_GLOBAL_FUN(mk_rewrite_rule_set, "mk_rewrite_rule_set");
|
||||
SET_GLOBAL_FUN(add_rewrite_rules, "add_rewrite_rules");
|
||||
SET_GLOBAL_FUN(enable_rewrite_rules, "enable_rewrite_rules");
|
||||
SET_GLOBAL_FUN(show_rewrite_rules, "show_rewrite_rules");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ Author: Leonardo de Moura
|
|||
#pragma once
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include "util/lua.h"
|
||||
#include "kernel/environment.h"
|
||||
#include "kernel/formatter.h"
|
||||
|
||||
|
@ -14,6 +15,7 @@ namespace lean {
|
|||
class rewrite_rule_set {
|
||||
class imp;
|
||||
std::unique_ptr<imp> m_ptr;
|
||||
rewrite_rule_set(imp * ptr);
|
||||
public:
|
||||
rewrite_rule_set(ro_environment const & env);
|
||||
rewrite_rule_set(rewrite_rule_set const & rs);
|
||||
|
@ -31,16 +33,12 @@ public:
|
|||
*/
|
||||
void insert(name const & th_name);
|
||||
|
||||
/**
|
||||
\brief Enable/disable the conditional rewrite rules tagged with the given identifier.
|
||||
*/
|
||||
/** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */
|
||||
void enable(name const & id, bool flag);
|
||||
/**
|
||||
\brief Return true iff the conditional rewrite rules tagged with the given id are enabled.
|
||||
*/
|
||||
/** \brief Return true iff the conditional rewrite rules tagged with the given id are enabled. */
|
||||
bool enabled(name const & id) const;
|
||||
|
||||
typedef std::function<bool(expr const &, expr const &, bool is_permutation, expr const &)> match_fn;
|
||||
typedef std::function<bool(expr const &, expr const &, bool is_permutation, expr const &)> match_fn; // NOLINT
|
||||
typedef std::function<void(name const &, expr const &, expr const &, bool)> visit_fn;
|
||||
|
||||
/**
|
||||
|
@ -53,16 +51,41 @@ public:
|
|||
|
||||
The argument \c proof is the proof for \c ceq.
|
||||
*/
|
||||
void for_each_match_candidate(expr const & e, match_fn const & fn);
|
||||
void for_each_match_candidate(expr const & e, match_fn const & fn) const;
|
||||
/** \brief Execute <tt>fn(id, ceq, proof, enabled)</tt> for each rule in this rule set. */
|
||||
void for_each(visit_fn const & fn) const;
|
||||
|
||||
/**
|
||||
\brief Execute <tt>fn(id, ceq, proof, enabled)</tt> for each rule in this rule set.
|
||||
*/
|
||||
void for_each(visit_fn const & fn);
|
||||
|
||||
/**
|
||||
\brief Pretty print this rule set.
|
||||
*/
|
||||
/** \brief Pretty print this rule set. */
|
||||
format pp(formatter const & fmt, options const & opts) const;
|
||||
};
|
||||
|
||||
/**
|
||||
\brief Create a rewrite rule set inside the given environment.
|
||||
|
||||
\remark The rule set is saved when the environment is serialized.
|
||||
\remark This procedure throws an exception if the environment already contains a rule set named \c rule_set_id.
|
||||
*/
|
||||
void mk_rewrite_rule_set(environment const & env, name const & rule_set_id);
|
||||
/**
|
||||
\brief Convert the theorem named \c th_name into conditional rewrite rules
|
||||
and insert them in the rule set named \c rule_set_id in the given environment.
|
||||
|
||||
\remark This procedure throws an exception if the environment does not have a theorem/axiom named \c th_name.
|
||||
\remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id.
|
||||
*/
|
||||
void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name);
|
||||
/**
|
||||
\brief Enable/disable the rewrite rules whose id is \c rule_id in the given rule set.
|
||||
|
||||
\remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id.
|
||||
*/
|
||||
void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag);
|
||||
/**
|
||||
\brief Return the rule set name \c rule_set_id in the given environment.
|
||||
|
||||
\remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id.
|
||||
*/
|
||||
rewrite_rule_set const & get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id);
|
||||
|
||||
void open_rewrite_rule_set(lua_State * L);
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
Set: pp::colors
|
||||
Set: pp::unicode
|
||||
Imported 'find'
|
||||
theorem congr1 {A : TypeU} {B : A → TypeU} {f g : ∀ x : A, B x} (a : A) (H : f = g) : f a = g a
|
||||
theorem congr1 {A B : TypeU} {f g : A → B} (a : A) (H : f = g) : f a = g a
|
||||
theorem congr2 {A B : TypeU} {a b : A} (f : A → B) (H : a = b) : f a = f b
|
||||
theorem congr {A B : TypeU} {f g : A → B} {a b : A} (H1 : f = g) (H2 : a = b) : f a = g b
|
||||
find.lean:3:0: error: executing external script (/home/leo/projects/lean/build/debug/shell/find.lua:24), no object name in the environment matches the regular expression 'foo'
|
||||
|
|
|
@ -10,7 +10,7 @@ Failed to solve
|
|||
with arguments:
|
||||
?M::3
|
||||
λ m : ℕ, Nat::add_zerol m ⋈ symm (Nat::add_zeror m)
|
||||
λ (n : ℕ) (iH : ?M::3 n) (m : ℕ),
|
||||
λ (n : ℕ) (iH : (?M::3[lift:0:1]) n) (m : ℕ),
|
||||
@trans ℕ
|
||||
(n + 1 + m)
|
||||
(m + n + 1)
|
||||
|
|
20
tests/lean/rw1.lean
Normal file
20
tests/lean/rw1.lean
Normal file
|
@ -0,0 +1,20 @@
|
|||
(*
|
||||
mk_rewrite_rule_set("rw1")
|
||||
add_rewrite_rules("rw1", "and_assoc")
|
||||
add_rewrite_rules("rw1", "and_truer")
|
||||
show_rewrite_rules("rw1")
|
||||
*)
|
||||
|
||||
scope
|
||||
print "new scope"
|
||||
(*
|
||||
add_rewrite_rules("rw1", "or_assoc")
|
||||
enable_rewrite_rules("rw1", "and_assoc", false)
|
||||
show_rewrite_rules("rw1")
|
||||
*)
|
||||
end
|
||||
|
||||
print "after end of scope"
|
||||
(*
|
||||
show_rewrite_rules("rw1")
|
||||
*)
|
11
tests/lean/rw1.lean.expected.out
Normal file
11
tests/lean/rw1.lean.expected.out
Normal file
|
@ -0,0 +1,11 @@
|
|||
Set: pp::colors
|
||||
Set: pp::unicode
|
||||
and_truer : ∀ a : Bool, a ∧ ⊤ ↔ a
|
||||
and_assoc : ∀ a b c : Bool, (a ∧ b) ∧ c ↔ a ∧ b ∧ c
|
||||
new scope
|
||||
or_assoc : ∀ a b c : Bool, (a ∨ b) ∨ c ↔ a ∨ b ∨ c
|
||||
and_truer : ∀ a : Bool, a ∧ ⊤ ↔ a
|
||||
and_assoc [disabled] : ∀ a b c : Bool, (a ∧ b) ∧ c ↔ a ∧ b ∧ c
|
||||
after end of scope
|
||||
and_truer : ∀ a : Bool, a ∧ ⊤ ↔ a
|
||||
and_assoc : ∀ a b c : Bool, (a ∧ b) ∧ c ↔ a ∧ b ∧ c
|
Loading…
Reference in a new issue