From b62e6bb133b9f11d1576560e69327ceb3cd176f1 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 1 Jun 2015 15:15:37 -0700 Subject: [PATCH] feat(library/simplifier): add rewrite rule sets --- src/init/init.cpp | 3 + src/library/head_map.h | 9 +- src/library/simplifier/CMakeLists.txt | 2 +- src/library/simplifier/ceqv.cpp | 8 +- src/library/simplifier/ceqv.h | 3 +- src/library/simplifier/init_module.cpp | 16 ++ src/library/simplifier/init_module.h | 12 ++ src/library/simplifier/rewrite_rule_set.cpp | 190 ++++++++++++++++++++ src/library/simplifier/rewrite_rule_set.h | 78 ++++++++ src/tests/library/head_map.cpp | 4 +- 10 files changed, 314 insertions(+), 11 deletions(-) create mode 100644 src/library/simplifier/init_module.cpp create mode 100644 src/library/simplifier/init_module.h create mode 100644 src/library/simplifier/rewrite_rule_set.cpp create mode 100644 src/library/simplifier/rewrite_rule_set.h diff --git a/src/init/init.cpp b/src/init/init.cpp index 115e022cf..3e69bde31 100644 --- a/src/init/init.cpp +++ b/src/init/init.cpp @@ -14,6 +14,7 @@ Author: Leonardo de Moura #include "kernel/quotient/quotient.h" #include "kernel/hits/hits.h" #include "library/init_module.h" +#include "library/simplifier/init_module.h" #include "library/tactic/init_module.h" #include "library/definitional/init_module.h" #include "library/print.h" @@ -33,6 +34,7 @@ void initialize() { initialize_hits_module(); init_default_print_fn(); initialize_library_module(); + initialize_simplifier_module(); initialize_tactic_module(); initialize_definitional_module(); initialize_frontend_lean_module(); @@ -43,6 +45,7 @@ void finalize() { finalize_frontend_lean_module(); finalize_definitional_module(); finalize_tactic_module(); + finalize_simplifier_module(); finalize_library_module(); finalize_hits_module(); finalize_quotient_module(); diff --git a/src/library/head_map.h b/src/library/head_map.h index 0da50d413..18b0a96ac 100644 --- a/src/library/head_map.h +++ b/src/library/head_map.h @@ -38,18 +38,21 @@ public: bool contains(head_index const & h) const { return m_map.contains(h); } list const * find(head_index const & h) const { return m_map.find(h); } void erase(head_index const & h) { m_map.erase(h); } - void erase_entry(head_index const & h, V const & v) { + template void filter(head_index const & h, P && p) { if (auto it = m_map.find(h)) { - auto new_vs = filter(*it, [&](V const & v2) { return v != v2; }); + auto new_vs = ::lean::filter(*it, std::forward

(p)); if (!new_vs) m_map.erase(h); else m_map.insert(h, new_vs); } } + void erase(head_index const & h, V const & v) { + return filter(h, [&](V const & v2) { return v != v2; }); + } void insert(head_index const & h, V const & v) { if (auto it = m_map.find(h)) - m_map.insert(h, cons(v, filter(*it, [&](V const & v2) { return v != v2; }))); + m_map.insert(h, cons(v, ::lean::filter(*it, [&](V const & v2) { return v != v2; }))); else m_map.insert(h, to_list(v)); } diff --git a/src/library/simplifier/CMakeLists.txt b/src/library/simplifier/CMakeLists.txt index db5dd1cf4..dde092baf 100644 --- a/src/library/simplifier/CMakeLists.txt +++ b/src/library/simplifier/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(simplifier ceqv.cpp) +add_library(simplifier ceqv.cpp rewrite_rule_set.cpp init_module.cpp) target_link_libraries(simplifier ${LEAN_LIBS}) diff --git a/src/library/simplifier/ceqv.cpp b/src/library/simplifier/ceqv.cpp index fe831351f..5544ba942 100644 --- a/src/library/simplifier/ceqv.cpp +++ b/src/library/simplifier/ceqv.cpp @@ -19,7 +19,7 @@ namespace lean { bool is_ceqv(type_checker & tc, expr e); /** \brief Auxiliary functional object for creating "conditional equations" */ -class to_ceqs_fn { +class to_ceqvs_fn { environment const & m_env; type_checker & m_tc; @@ -99,15 +99,15 @@ class to_ceqs_fn { } } public: - to_ceqs_fn(type_checker & tc):m_env(tc.env()), m_tc(tc) {} + to_ceqvs_fn(type_checker & tc):m_env(tc.env()), m_tc(tc) {} list operator()(expr const & e, expr const & H) { return filter(apply(e, H), [&](expr_pair const & p) { return is_ceqv(m_tc, p.first); }); } }; -list to_ceqs(type_checker & tc, expr const & e, expr const & H) { - return to_ceqs_fn(tc)(e, H); +list to_ceqvs(type_checker & tc, expr const & e, expr const & H) { + return to_ceqvs_fn(tc)(e, H); } bool is_equivalence(environment const & env, expr const & e, expr & rel, expr & lhs, expr & rhs) { diff --git a/src/library/simplifier/ceqv.h b/src/library/simplifier/ceqv.h index 051a821ba..cac61d38b 100644 --- a/src/library/simplifier/ceqv.h +++ b/src/library/simplifier/ceqv.h @@ -8,11 +8,12 @@ Author: Leonardo de Moura #include "kernel/type_checker.h" namespace lean { +bool is_equivalence(environment const & env, expr const & e, expr & rel, expr & lhs, expr & rhs); /** \brief Given (H : e), return a list of (h_i : e_i) where e_i can be viewed as a "conditional" rewriting rule. Any equivalence relation registered using the relation_manager is considered. */ -list to_ceqs(type_checker & tc, expr const & e, expr const & H); +list to_ceqvs(type_checker & tc, expr const & e, expr const & H); bool is_ceqv(type_checker & tc, expr e); bool is_permutation_ceqv(environment const & env, expr e); } diff --git a/src/library/simplifier/init_module.cpp b/src/library/simplifier/init_module.cpp new file mode 100644 index 000000000..3c4397154 --- /dev/null +++ b/src/library/simplifier/init_module.cpp @@ -0,0 +1,16 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "library/simplifier/rewrite_rule_set.h" + +namespace lean { +void initialize_simplifier_module() { + initialize_rewrite_rule_set(); +} +void finalize_simplifier_module() { + finalize_rewrite_rule_set(); +} +} diff --git a/src/library/simplifier/init_module.h b/src/library/simplifier/init_module.h new file mode 100644 index 000000000..bb1b950cf --- /dev/null +++ b/src/library/simplifier/init_module.h @@ -0,0 +1,12 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once + +namespace lean { +void initialize_simplifier_module(); +void finalize_simplifier_module(); +} diff --git a/src/library/simplifier/rewrite_rule_set.cpp b/src/library/simplifier/rewrite_rule_set.cpp new file mode 100644 index 000000000..d4f43d251 --- /dev/null +++ b/src/library/simplifier/rewrite_rule_set.cpp @@ -0,0 +1,190 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include +#include "kernel/instantiate.h" +#include "library/scoped_ext.h" +#include "library/expr_pair.h" +#include "library/relation_manager.h" +#include "library/simplifier/ceqv.h" +#include "library/simplifier/rewrite_rule_set.h" + +namespace lean { +bool operator==(rewrite_rule const & r1, rewrite_rule const & r2) { + return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs; +} + +rewrite_rule::rewrite_rule(name const & id, list const & metas, + expr const & lhs, expr const & rhs, constraint_seq const & cs, + expr const & proof, bool is_perm): + m_id(id), m_metas(metas), m_lhs(lhs), m_rhs(rhs), m_cs(cs), m_proof(proof), + m_is_permutation(is_perm) {} + +rewrite_rule_set::rewrite_rule_set(name const & eqv): + m_eqv(eqv) {} + +void rewrite_rule_set::insert(rewrite_rule const & r) { + m_set.insert(r.get_lhs(), r); +} + +list const * rewrite_rule_set::find(head_index const & h) const { + return m_set.find(h); +} + +void rewrite_rule_set::for_each(std::function const & fn) const { + m_set.for_each_entry([&](head_index const &, rewrite_rule const & r) { fn(r); }); +} + +void rewrite_rule_set::erase(rewrite_rule const & r) { + m_set.erase(r.get_lhs(), r); +} + +void rewrite_rule_sets::insert(name const & eqv, rewrite_rule const & r) { + rewrite_rule_set s(eqv); + if (auto const * curr = m_sets.find(eqv)) { + s = *curr; + } + s.insert(r); + m_sets.insert(eqv, s); +} + +void rewrite_rule_sets::erase(name const & eqv, rewrite_rule const & r) { + if (auto const * curr = m_sets.find(eqv)) { + rewrite_rule_set s = *curr; + s.erase(r); + if (s.empty()) + m_sets.erase(eqv); + else + m_sets.insert(eqv, s); + } +} + +void rewrite_rule_sets::get_relations(buffer & rs) const { + m_sets.for_each([&](name const & r, rewrite_rule_set const &) { + rs.push_back(r); + }); +} + +rewrite_rule_set const * rewrite_rule_sets::find(name const & eqv) const { + return m_sets.find(eqv); +} + +list const * rewrite_rule_sets::find(name const & eqv, head_index const & h) const { + if (auto const * s = m_sets.find(eqv)) + return s->find(h); + return nullptr; +} + +void rewrite_rule_sets::for_each(std::function const & fn) const { + m_sets.for_each([&](name const & eqv, rewrite_rule_set const & s) { + s.for_each([&](rewrite_rule const & r) { + fn(eqv, r); + }); + }); +} + +static name * g_prefix = nullptr; + +rewrite_rule_sets add(type_checker & tc, rewrite_rule_sets const & s, name const & id, expr const & e, expr const & h) { + list ceqvs = to_ceqvs(tc, e, h); + environment const & env = tc.env(); + rewrite_rule_sets new_s = s; + for (expr_pair const & p : ceqvs) { + expr new_e = p.first; + expr new_h = p.second; + bool is_perm = is_permutation_ceqv(env, new_e); + buffer metas; + constraint_seq cs; + unsigned idx = 0; + while (is_pi(new_e)) { + expr mvar = mk_metavar(name(*g_prefix, idx), binding_domain(new_e)); + idx++; + metas.push_back(mvar); + // TODO(Leo): type class constraints + new_e = instantiate(binding_body(new_e), mvar); + } + expr rel, lhs, rhs; + if (is_equivalence(env, new_e, rel, lhs, rhs) && is_constant(rel)) { + new_s.insert(const_name(rel), rewrite_rule(id, to_list(metas), lhs, rhs, cs, new_h, is_perm)); + } + } + return new_s; +} + +rewrite_rule_sets join(rewrite_rule_sets const & s1, rewrite_rule_sets const & s2) { + rewrite_rule_sets new_s1 = s1; + s2.for_each([&](name const & eqv, rewrite_rule const & r) { + new_s1.insert(eqv, r); + }); + return new_s1; +} + +static name * g_class_name = nullptr; +static std::string * g_key = nullptr; + +struct rrs_state { + rewrite_rule_sets m_sets; + name_map m_namespace_cache; + + void add(environment const & env, name const & cname) { + // TODO(Leo): universe variables + // TODO(Leo): invalide cache for current namespace + type_checker tc(env); + declaration const & d = env.get(cname); + expr e = d.get_type(); + expr h = mk_constant(cname); + m_sets = ::lean::add(tc, m_sets, cname, e, h); + } +}; + +struct rrs_config { + typedef name entry; + typedef rrs_state state; + static void add_entry(environment const & env, io_state const &, state & s, entry const & e) { + s.add(env, e); + } + static name const & get_class_name() { + return *g_class_name; + } + static std::string const & get_serialization_key() { + return *g_key; + } + static void write_entry(serializer & s, entry const & e) { + s << e; + } + static entry read_entry(deserializer & d) { + entry e; d >> e; return e; + } + static optional get_fingerprint(entry const & e) { + return some(e.hash()); + } +}; + +template class scoped_ext; +typedef scoped_ext rrs_ext; + +environment add_rewrite_rule(environment const & env, name const & n, bool persistent) { + return rrs_ext::add_entry(env, get_dummy_ios(), n, persistent); +} + +rewrite_rule_sets get_rewrite_rule_sets(environment const & env) { + return rrs_ext::get_state(env).m_sets; +} + +void initialize_rewrite_rule_set() { + g_prefix = new name(name::mk_internal_unique_name()); + g_class_name = new name("rrs"); + g_key = new std::string("rrs"); + rrs_ext::initialize(); +} + +void finalize_rewrite_rule_set() { + rrs_ext::finalize(); + delete g_key; + delete g_class_name; + delete g_prefix; +} +} diff --git a/src/library/simplifier/rewrite_rule_set.h b/src/library/simplifier/rewrite_rule_set.h new file mode 100644 index 000000000..df5c712f5 --- /dev/null +++ b/src/library/simplifier/rewrite_rule_set.h @@ -0,0 +1,78 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "kernel/type_checker.h" +#include "library/head_map.h" + +namespace lean { +class rewrite_rule_sets; + +class rewrite_rule { + name m_id; + list m_metas; + expr m_lhs; + expr m_rhs; + constraint_seq m_cs; + expr m_proof; + bool m_is_permutation; + rewrite_rule(name const & id, list const & metas, + expr const & lhs, expr const & rhs, constraint_seq const & cs, + expr const & proof, bool is_perm); + friend rewrite_rule_sets add(type_checker & tc, rewrite_rule_sets const & s, name const & id, + expr const & e, expr const & h); +public: + name const & get_id() const { return m_id; } + list const & get_metas() const { return m_metas; } + expr const & get_lhs() const { return m_lhs; } + expr const & get_rhs() const { return m_rhs; } + constraint_seq const & get_cs() const { return m_cs; } + expr const & get_proof() const { return m_proof; } + bool is_perm() const { return m_is_permutation; } + friend bool operator==(rewrite_rule const & r1, rewrite_rule const & r2); +}; + +bool operator==(rewrite_rule const & r1, rewrite_rule const & r2); +inline bool operator!=(rewrite_rule const & r1, rewrite_rule const & r2) { return !operator==(r1, r2); } + +class rewrite_rule_set { + typedef head_map rule_set; + name m_eqv; + rule_set m_set; +public: + rewrite_rule_set() {} + /** \brief Return the equivalence relation associated with this set */ + rewrite_rule_set(name const & eqv); + bool empty() const { return m_set.empty(); } + name const & get_eqv() const { return m_eqv; } + void insert(rewrite_rule const & r); + void erase(rewrite_rule const & r); + list const * find(head_index const & h) const; + void for_each(std::function const & fn) const; +}; + +class rewrite_rule_sets { + name_map m_sets; // mapping from relation name to rewrite_rule_set +public: + void insert(name const & eqv, rewrite_rule const & r); + void erase(name const & eqv, rewrite_rule const & r); + void get_relations(buffer & rs) const; + rewrite_rule_set const * find(name const & eqv) const; + list const * find(name const & eqv, head_index const & h) const; + void for_each(std::function const & fn) const; +}; + +rewrite_rule_sets add(type_checker & tc, rewrite_rule_sets const & s, name const & id, expr const & e, expr const & h); +rewrite_rule_sets join(rewrite_rule_sets const & s1, rewrite_rule_sets const & s2); + +environment add_rewrite_rule(environment const & env, name const & n, bool persistent = true); +/** \brief Get current rewrite rule sets */ +rewrite_rule_sets get_rewrite_rule_sets(environment const & env); +/** \brief Get rewrite rule sets in the given namespace. */ +rewrite_rule_sets get_rewrite_rule_set(environment const & env, name const & ns); +void initialize_rewrite_rule_set(); +void finalize_rewrite_rule_set(); +} diff --git a/src/tests/library/head_map.cpp b/src/tests/library/head_map.cpp index cf71bdb28..d41233575 100644 --- a/src/tests/library/head_map.cpp +++ b/src/tests/library/head_map.cpp @@ -28,9 +28,9 @@ static void tst1() { lean_assert(map.contains(a)); map.insert(a, b); lean_assert(map.contains(a)); - map.erase_entry(a, b); + map.erase(a, b); lean_assert(map.contains(a)); - map.erase_entry(a, a); + map.erase(a, a); lean_assert(!map.contains(a)); lean_assert(map.empty()); map.insert(l1, a);