diff --git a/src/library/simplifier/register_module.h b/src/library/simplifier/register_module.h index f66488490..5cf681caa 100644 --- a/src/library/simplifier/register_module.h +++ b/src/library/simplifier/register_module.h @@ -7,11 +7,13 @@ Author: Leonardo de Moura #pragma once #include "util/script_state.h" #include "library/simplifier/ceq.h" +#include "library/simplifier/rewrite_rule_set.h" #include "library/simplifier/simplifier.h" namespace lean { inline void open_simplifier_module(lua_State * L) { open_ceq(L); + open_rewrite_rule_set(L); open_simplifier(L); } inline void register_simplifier_module() { diff --git a/src/library/simplifier/rewrite_rule_set.cpp b/src/library/simplifier/rewrite_rule_set.cpp index 7174f87d4..9e15bc709 100644 --- a/src/library/simplifier/rewrite_rule_set.cpp +++ b/src/library/simplifier/rewrite_rule_set.cpp @@ -8,7 +8,9 @@ Author: Leonardo de Moura #include "util/list_fn.h" #include "util/sstream.h" #include "kernel/environment.h" +#include "library/io_state_stream.h" #include "library/equality.h" +#include "library/kernel_bindings.h" #include "library/simplifier/ceq.h" #include "library/simplifier/rewrite_rule_set.h" @@ -108,14 +110,218 @@ public: return r; } }; +rewrite_rule_set::rewrite_rule_set(imp * ptr):m_ptr(ptr) {} rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_ptr(new imp(env)) {} -rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & rs):m_ptr(new imp(*(rs.m_ptr))) {} +rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & rs):m_ptr(new imp(*rs.m_ptr)) {} rewrite_rule_set::~rewrite_rule_set() {} void rewrite_rule_set::insert(name const & id, expr const & th, expr const & proof) { m_ptr->insert(id, th, proof); } void rewrite_rule_set::insert(name const & th_name) { m_ptr->insert(th_name); } bool rewrite_rule_set::enabled(name const & id) const { return m_ptr->enabled(id); } void rewrite_rule_set::enable(name const & id, bool f) { m_ptr->enable(id, f); } -void rewrite_rule_set::for_each_match_candidate(expr const & e, match_fn const & fn) { m_ptr->for_each_match_candidate(e, fn); } -void rewrite_rule_set::for_each(visit_fn const & fn) { m_ptr->for_each(fn); } -format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const { return m_ptr->pp(fmt, opts); } +void rewrite_rule_set::for_each_match_candidate(expr const & e, match_fn const & fn) const { + m_ptr->for_each_match_candidate(e, fn); +} +void rewrite_rule_set::for_each(visit_fn const & fn) const { m_ptr->for_each(fn); } +format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const { return m_ptr->pp(fmt, opts); } + +class mk_rewrite_rule_set_obj : public neutral_object_cell { + name m_rule_set_id; +public: + mk_rewrite_rule_set_obj(name const & id):m_rule_set_id(id) {} + virtual ~mk_rewrite_rule_set_obj() {} + virtual char const * keyword() const { return "mk_rewrite_rule_set"; } + virtual void write(serializer & s) const { s << "mk_rrs" << m_rule_set_id; } +}; +static void read_rrs(environment const & env, io_state const &, deserializer & d) { + name n = read_name(d); + mk_rewrite_rule_set(env, n); +} +static object_cell::register_deserializer_fn mk_rrs_ds("mk_rrs", read_rrs); + +class add_rewrite_rules_obj : public neutral_object_cell { + name m_rule_set_id; + name m_th_name; +public: + add_rewrite_rules_obj(name const & rsid, name const & th_name):m_rule_set_id(rsid), m_th_name(th_name) {} + virtual ~add_rewrite_rules_obj() {} + virtual char const * keyword() const { return "add_rewrite_rules"; } + virtual void write(serializer & s) const { s << "add_rr" << m_rule_set_id << m_th_name; } +}; +static void read_arr(environment const & env, io_state const &, deserializer & d) { + name rsid = read_name(d); + name th = read_name(d); + add_rewrite_rules(env, rsid, th); +} +static object_cell::register_deserializer_fn add_rr_ds("add_rr", read_arr); + +class enable_rewrite_rules_obj : public neutral_object_cell { + name m_rule_set_id; + name m_rule_id; + bool m_flag; +public: + enable_rewrite_rules_obj(name const & rsid, name const & id, bool flag):m_rule_set_id(rsid), m_rule_id(id), m_flag(flag) {} + virtual ~enable_rewrite_rules_obj() {} + virtual char const * keyword() const { return "enable_rewrite_rules_obj"; } + virtual void write(serializer & s) const { s << "enable_rr" << m_rule_set_id << m_rule_id << m_flag; } +}; +static void read_enable_rr(environment const & env, io_state const &, deserializer & d) { + name rsid = read_name(d); + name id = read_name(d); + bool flag = d.read_bool(); + enable_rewrite_rules(env, rsid, id, flag); +} +static object_cell::register_deserializer_fn enable_rr_ds("enable_rr", read_enable_rr); + +/** + \brief Extension for managing rewrite rule sets. +*/ +struct rewrite_rule_set_extension : public environment_extension { + name_map m_rule_sets; + + rewrite_rule_set_extension const * get_parent() const { + return environment_extension::get_parent(); + } + + rewrite_rule_set const * find_ro_core(name const & rule_set_id) const { + auto it = m_rule_sets.find(rule_set_id); + if (it != m_rule_sets.end()) { + return &(it->second); + } + auto p = get_parent(); + if (p) { + return p->find_ro_core(rule_set_id); + } else { + return nullptr; + } + } + + rewrite_rule_set const & find_ro(name const & rule_set_id) const { + auto rs = find_ro_core(rule_set_id); + if (rs) + return *rs; + throw exception(sstream() << "environment does not contain a rewrite rule set named '" << rule_set_id << "'"); + } + + rewrite_rule_set & find_rw(name const & rule_set_id) { + auto it = m_rule_sets.find(rule_set_id); + if (it != m_rule_sets.end()) + return it->second; + auto p = get_parent(); + if (p) { + auto const & rs = p->find_ro(rule_set_id); + m_rule_sets.insert(mk_pair(rule_set_id, rewrite_rule_set(rs))); + return m_rule_sets.find(rule_set_id)->second; + } + throw exception(sstream() << "environment does not contain a rewrite rule set named '" << rule_set_id << "'"); + } + + void mk_rewrite_rule_set(environment const & env, name const & rule_set_id) { + if (find_ro_core(rule_set_id)) + throw exception(sstream() << "environment already contains a rewrite rule set named '" << rule_set_id << "'"); + m_rule_sets.insert(mk_pair(rule_set_id, rewrite_rule_set(env))); + env->add_neutral_object(new mk_rewrite_rule_set_obj(rule_set_id)); + } + + void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name) { + auto & rs = find_rw(rule_set_id); + rs.insert(th_name); + env->add_neutral_object(new add_rewrite_rules_obj(rule_set_id, th_name)); + } + + void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag) { + auto & rs = find_rw(rule_set_id); + rs.enable(rule_id, flag); + env->add_neutral_object(new enable_rewrite_rules_obj(rule_set_id, rule_id, flag)); + } + + rewrite_rule_set const & get_rewrite_rule_set(name const & rule_set_id) const { + return find_ro(rule_set_id); + } +}; + +struct rewrite_rule_set_extension_initializer { + unsigned m_extid; + rewrite_rule_set_extension_initializer() { + m_extid = environment_cell::register_extension([](){ + return std::unique_ptr(new rewrite_rule_set_extension()); + }); + } +}; + +static rewrite_rule_set_extension_initializer g_rewrite_rule_set_extension_initializer; + +static rewrite_rule_set_extension const & to_ext(ro_environment const & env) { + return env->get_extension(g_rewrite_rule_set_extension_initializer.m_extid); +} + +static rewrite_rule_set_extension & to_ext(environment const & env) { + return env->get_extension(g_rewrite_rule_set_extension_initializer.m_extid); +} + +void mk_rewrite_rule_set(environment const & env, name const & rule_set_id) { + to_ext(env).mk_rewrite_rule_set(env, rule_set_id); +} + +void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name) { + to_ext(env).add_rewrite_rules(env, rule_set_id, th_name); +} + +void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag) { + to_ext(env).enable_rewrite_rules(env, rule_set_id, rule_id, flag); +} + +rewrite_rule_set const & get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) { + return to_ext(env).get_rewrite_rule_set(rule_set_id); +} + +static int mk_rewrite_rule_set(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs == 1) + mk_rewrite_rule_set(rw_shared_environment(L), to_name_ext(L, 1)); + else + mk_rewrite_rule_set(rw_shared_environment(L, 2), to_name_ext(L, 1)); + return 0; +} + +static int add_rewrite_rules(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs == 2) + add_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2)); + else + add_rewrite_rules(rw_shared_environment(L, 3), to_name_ext(L, 1), to_name_ext(L, 2)); + return 0; +} + +static int enable_rewrite_rules(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs == 3) + enable_rewrite_rules(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3)); + else + enable_rewrite_rules(rw_shared_environment(L, 4), to_name_ext(L, 1), to_name_ext(L, 2), lua_toboolean(L, 3)); + return 0; +} + +static int show_rewrite_rules(lua_State * L) { + int nargs = lua_gettop(L); + formatter fmt = get_global_formatter(L); + options opts = get_global_options(L); + format r; + if (nargs == 1) + r = get_rewrite_rule_set(ro_shared_environment(L), to_name_ext(L, 1)).pp(fmt, opts); + else + r = get_rewrite_rule_set(ro_shared_environment(L, 2), to_name_ext(L, 1)).pp(fmt, opts); + io_state * ios = get_io_state(L); + if (ios) + regular(*ios) << mk_pair(r, opts) << endl; + else + std::cout << mk_pair(r, opts) << "\n"; + return 0; +} + +void open_rewrite_rule_set(lua_State * L) { + SET_GLOBAL_FUN(mk_rewrite_rule_set, "mk_rewrite_rule_set"); + SET_GLOBAL_FUN(add_rewrite_rules, "add_rewrite_rules"); + SET_GLOBAL_FUN(enable_rewrite_rules, "enable_rewrite_rules"); + SET_GLOBAL_FUN(show_rewrite_rules, "show_rewrite_rules"); +} } diff --git a/src/library/simplifier/rewrite_rule_set.h b/src/library/simplifier/rewrite_rule_set.h index da35158ee..e0c6381ac 100644 --- a/src/library/simplifier/rewrite_rule_set.h +++ b/src/library/simplifier/rewrite_rule_set.h @@ -7,6 +7,7 @@ Author: Leonardo de Moura #pragma once #include #include +#include "util/lua.h" #include "kernel/environment.h" #include "kernel/formatter.h" @@ -14,6 +15,7 @@ namespace lean { class rewrite_rule_set { class imp; std::unique_ptr m_ptr; + rewrite_rule_set(imp * ptr); public: rewrite_rule_set(ro_environment const & env); rewrite_rule_set(rewrite_rule_set const & rs); @@ -31,16 +33,12 @@ public: */ void insert(name const & th_name); - /** - \brief Enable/disable the conditional rewrite rules tagged with the given identifier. - */ + /** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */ void enable(name const & id, bool flag); - /** - \brief Return true iff the conditional rewrite rules tagged with the given id are enabled. - */ + /** \brief Return true iff the conditional rewrite rules tagged with the given id are enabled. */ bool enabled(name const & id) const; - typedef std::function match_fn; + typedef std::function match_fn; // NOLINT typedef std::function visit_fn; /** @@ -53,16 +51,41 @@ public: The argument \c proof is the proof for \c ceq. */ - void for_each_match_candidate(expr const & e, match_fn const & fn); + void for_each_match_candidate(expr const & e, match_fn const & fn) const; + /** \brief Execute fn(id, ceq, proof, enabled) for each rule in this rule set. */ + void for_each(visit_fn const & fn) const; - /** - \brief Execute fn(id, ceq, proof, enabled) for each rule in this rule set. - */ - void for_each(visit_fn const & fn); - - /** - \brief Pretty print this rule set. - */ + /** \brief Pretty print this rule set. */ format pp(formatter const & fmt, options const & opts) const; }; + +/** + \brief Create a rewrite rule set inside the given environment. + + \remark The rule set is saved when the environment is serialized. + \remark This procedure throws an exception if the environment already contains a rule set named \c rule_set_id. +*/ +void mk_rewrite_rule_set(environment const & env, name const & rule_set_id); +/** + \brief Convert the theorem named \c th_name into conditional rewrite rules + and insert them in the rule set named \c rule_set_id in the given environment. + + \remark This procedure throws an exception if the environment does not have a theorem/axiom named \c th_name. + \remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id. +*/ +void add_rewrite_rules(environment const & env, name const & rule_set_id, name const & th_name); +/** + \brief Enable/disable the rewrite rules whose id is \c rule_id in the given rule set. + + \remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id. +*/ +void enable_rewrite_rules(environment const & env, name const & rule_set_id, name const & rule_id, bool flag); +/** + \brief Return the rule set name \c rule_set_id in the given environment. + + \remark This procedure throws an exception if the environment does not have a rule set named \c rule_set_id. +*/ +rewrite_rule_set const & get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id); + +void open_rewrite_rule_set(lua_State * L); } diff --git a/tests/lean/find.lean.expected.out b/tests/lean/find.lean.expected.out index e403cbf09..875c5f8ed 100644 --- a/tests/lean/find.lean.expected.out +++ b/tests/lean/find.lean.expected.out @@ -1,7 +1,7 @@ Set: pp::colors Set: pp::unicode Imported 'find' -theorem congr1 {A : TypeU} {B : A → TypeU} {f g : ∀ x : A, B x} (a : A) (H : f = g) : f a = g a +theorem congr1 {A B : TypeU} {f g : A → B} (a : A) (H : f = g) : f a = g a theorem congr2 {A B : TypeU} {a b : A} (f : A → B) (H : a = b) : f a = f b theorem congr {A B : TypeU} {f g : A → B} {a b : A} (H1 : f = g) (H2 : a = b) : f a = g b find.lean:3:0: error: executing external script (/home/leo/projects/lean/build/debug/shell/find.lua:24), no object name in the environment matches the regular expression 'foo' diff --git a/tests/lean/induction2.lean.expected.out b/tests/lean/induction2.lean.expected.out index 5a1342406..ea82d9d96 100644 --- a/tests/lean/induction2.lean.expected.out +++ b/tests/lean/induction2.lean.expected.out @@ -10,7 +10,7 @@ Failed to solve with arguments: ?M::3 λ m : ℕ, Nat::add_zerol m ⋈ symm (Nat::add_zeror m) - λ (n : ℕ) (iH : ?M::3 n) (m : ℕ), + λ (n : ℕ) (iH : (?M::3[lift:0:1]) n) (m : ℕ), @trans ℕ (n + 1 + m) (m + n + 1) diff --git a/tests/lean/rw1.lean b/tests/lean/rw1.lean new file mode 100644 index 000000000..492a660ed --- /dev/null +++ b/tests/lean/rw1.lean @@ -0,0 +1,20 @@ +(* +mk_rewrite_rule_set("rw1") +add_rewrite_rules("rw1", "and_assoc") +add_rewrite_rules("rw1", "and_truer") +show_rewrite_rules("rw1") +*) + +scope + print "new scope" + (* + add_rewrite_rules("rw1", "or_assoc") + enable_rewrite_rules("rw1", "and_assoc", false) + show_rewrite_rules("rw1") + *) +end + +print "after end of scope" +(* + show_rewrite_rules("rw1") +*) diff --git a/tests/lean/rw1.lean.expected.out b/tests/lean/rw1.lean.expected.out new file mode 100644 index 000000000..0e5019e31 --- /dev/null +++ b/tests/lean/rw1.lean.expected.out @@ -0,0 +1,11 @@ + Set: pp::colors + Set: pp::unicode +and_truer : ∀ a : Bool, a ∧ ⊤ ↔ a +and_assoc : ∀ a b c : Bool, (a ∧ b) ∧ c ↔ a ∧ b ∧ c +new scope +or_assoc : ∀ a b c : Bool, (a ∨ b) ∨ c ↔ a ∨ b ∨ c +and_truer : ∀ a : Bool, a ∧ ⊤ ↔ a +and_assoc [disabled] : ∀ a b c : Bool, (a ∧ b) ∧ c ↔ a ∧ b ∧ c +after end of scope +and_truer : ∀ a : Bool, a ∧ ⊤ ↔ a +and_assoc : ∀ a b c : Bool, (a ∧ b) ∧ c ↔ a ∧ b ∧ c