feat(library/simplifier): add rewrite rule sets

This commit is contained in:
Leonardo de Moura 2015-06-01 15:15:37 -07:00
parent 780d313686
commit b62e6bb133
10 changed files with 314 additions and 11 deletions

View file

@ -14,6 +14,7 @@ Author: Leonardo de Moura
#include "kernel/quotient/quotient.h"
#include "kernel/hits/hits.h"
#include "library/init_module.h"
#include "library/simplifier/init_module.h"
#include "library/tactic/init_module.h"
#include "library/definitional/init_module.h"
#include "library/print.h"
@ -33,6 +34,7 @@ void initialize() {
initialize_hits_module();
init_default_print_fn();
initialize_library_module();
initialize_simplifier_module();
initialize_tactic_module();
initialize_definitional_module();
initialize_frontend_lean_module();
@ -43,6 +45,7 @@ void finalize() {
finalize_frontend_lean_module();
finalize_definitional_module();
finalize_tactic_module();
finalize_simplifier_module();
finalize_library_module();
finalize_hits_module();
finalize_quotient_module();

View file

@ -38,18 +38,21 @@ public:
bool contains(head_index const & h) const { return m_map.contains(h); }
list<V> const * find(head_index const & h) const { return m_map.find(h); }
void erase(head_index const & h) { m_map.erase(h); }
void erase_entry(head_index const & h, V const & v) {
template<typename P> void filter(head_index const & h, P && p) {
if (auto it = m_map.find(h)) {
auto new_vs = filter(*it, [&](V const & v2) { return v != v2; });
auto new_vs = ::lean::filter(*it, std::forward<P>(p));
if (!new_vs)
m_map.erase(h);
else
m_map.insert(h, new_vs);
}
}
void erase(head_index const & h, V const & v) {
return filter(h, [&](V const & v2) { return v != v2; });
}
void insert(head_index const & h, V const & v) {
if (auto it = m_map.find(h))
m_map.insert(h, cons(v, filter(*it, [&](V const & v2) { return v != v2; })));
m_map.insert(h, cons(v, ::lean::filter(*it, [&](V const & v2) { return v != v2; })));
else
m_map.insert(h, to_list(v));
}

View file

@ -1,2 +1,2 @@
add_library(simplifier ceqv.cpp)
add_library(simplifier ceqv.cpp rewrite_rule_set.cpp init_module.cpp)
target_link_libraries(simplifier ${LEAN_LIBS})

View file

@ -19,7 +19,7 @@ namespace lean {
bool is_ceqv(type_checker & tc, expr e);
/** \brief Auxiliary functional object for creating "conditional equations" */
class to_ceqs_fn {
class to_ceqvs_fn {
environment const & m_env;
type_checker & m_tc;
@ -99,15 +99,15 @@ class to_ceqs_fn {
}
}
public:
to_ceqs_fn(type_checker & tc):m_env(tc.env()), m_tc(tc) {}
to_ceqvs_fn(type_checker & tc):m_env(tc.env()), m_tc(tc) {}
list<expr_pair> operator()(expr const & e, expr const & H) {
return filter(apply(e, H), [&](expr_pair const & p) { return is_ceqv(m_tc, p.first); });
}
};
list<expr_pair> to_ceqs(type_checker & tc, expr const & e, expr const & H) {
return to_ceqs_fn(tc)(e, H);
list<expr_pair> to_ceqvs(type_checker & tc, expr const & e, expr const & H) {
return to_ceqvs_fn(tc)(e, H);
}
bool is_equivalence(environment const & env, expr const & e, expr & rel, expr & lhs, expr & rhs) {

View file

@ -8,11 +8,12 @@ Author: Leonardo de Moura
#include "kernel/type_checker.h"
namespace lean {
bool is_equivalence(environment const & env, expr const & e, expr & rel, expr & lhs, expr & rhs);
/** \brief Given (H : e), return a list of (h_i : e_i) where e_i can be viewed as
a "conditional" rewriting rule. Any equivalence relation registered using
the relation_manager is considered.
*/
list<expr_pair> to_ceqs(type_checker & tc, expr const & e, expr const & H);
list<expr_pair> to_ceqvs(type_checker & tc, expr const & e, expr const & H);
bool is_ceqv(type_checker & tc, expr e);
bool is_permutation_ceqv(environment const & env, expr e);
}

View file

@ -0,0 +1,16 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "library/simplifier/rewrite_rule_set.h"
namespace lean {
void initialize_simplifier_module() {
initialize_rewrite_rule_set();
}
void finalize_simplifier_module() {
finalize_rewrite_rule_set();
}
}

View file

@ -0,0 +1,12 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
namespace lean {
void initialize_simplifier_module();
void finalize_simplifier_module();
}

View file

@ -0,0 +1,190 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <string>
#include "kernel/instantiate.h"
#include "library/scoped_ext.h"
#include "library/expr_pair.h"
#include "library/relation_manager.h"
#include "library/simplifier/ceqv.h"
#include "library/simplifier/rewrite_rule_set.h"
namespace lean {
bool operator==(rewrite_rule const & r1, rewrite_rule const & r2) {
return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs;
}
rewrite_rule::rewrite_rule(name const & id, list<expr> const & metas,
expr const & lhs, expr const & rhs, constraint_seq const & cs,
expr const & proof, bool is_perm):
m_id(id), m_metas(metas), m_lhs(lhs), m_rhs(rhs), m_cs(cs), m_proof(proof),
m_is_permutation(is_perm) {}
rewrite_rule_set::rewrite_rule_set(name const & eqv):
m_eqv(eqv) {}
void rewrite_rule_set::insert(rewrite_rule const & r) {
m_set.insert(r.get_lhs(), r);
}
list<rewrite_rule> const * rewrite_rule_set::find(head_index const & h) const {
return m_set.find(h);
}
void rewrite_rule_set::for_each(std::function<void(rewrite_rule const &)> const & fn) const {
m_set.for_each_entry([&](head_index const &, rewrite_rule const & r) { fn(r); });
}
void rewrite_rule_set::erase(rewrite_rule const & r) {
m_set.erase(r.get_lhs(), r);
}
void rewrite_rule_sets::insert(name const & eqv, rewrite_rule const & r) {
rewrite_rule_set s(eqv);
if (auto const * curr = m_sets.find(eqv)) {
s = *curr;
}
s.insert(r);
m_sets.insert(eqv, s);
}
void rewrite_rule_sets::erase(name const & eqv, rewrite_rule const & r) {
if (auto const * curr = m_sets.find(eqv)) {
rewrite_rule_set s = *curr;
s.erase(r);
if (s.empty())
m_sets.erase(eqv);
else
m_sets.insert(eqv, s);
}
}
void rewrite_rule_sets::get_relations(buffer<name> & rs) const {
m_sets.for_each([&](name const & r, rewrite_rule_set const &) {
rs.push_back(r);
});
}
rewrite_rule_set const * rewrite_rule_sets::find(name const & eqv) const {
return m_sets.find(eqv);
}
list<rewrite_rule> const * rewrite_rule_sets::find(name const & eqv, head_index const & h) const {
if (auto const * s = m_sets.find(eqv))
return s->find(h);
return nullptr;
}
void rewrite_rule_sets::for_each(std::function<void(name const &, rewrite_rule const &)> const & fn) const {
m_sets.for_each([&](name const & eqv, rewrite_rule_set const & s) {
s.for_each([&](rewrite_rule const & r) {
fn(eqv, r);
});
});
}
static name * g_prefix = nullptr;
rewrite_rule_sets add(type_checker & tc, rewrite_rule_sets const & s, name const & id, expr const & e, expr const & h) {
list<expr_pair> ceqvs = to_ceqvs(tc, e, h);
environment const & env = tc.env();
rewrite_rule_sets new_s = s;
for (expr_pair const & p : ceqvs) {
expr new_e = p.first;
expr new_h = p.second;
bool is_perm = is_permutation_ceqv(env, new_e);
buffer<expr> metas;
constraint_seq cs;
unsigned idx = 0;
while (is_pi(new_e)) {
expr mvar = mk_metavar(name(*g_prefix, idx), binding_domain(new_e));
idx++;
metas.push_back(mvar);
// TODO(Leo): type class constraints
new_e = instantiate(binding_body(new_e), mvar);
}
expr rel, lhs, rhs;
if (is_equivalence(env, new_e, rel, lhs, rhs) && is_constant(rel)) {
new_s.insert(const_name(rel), rewrite_rule(id, to_list(metas), lhs, rhs, cs, new_h, is_perm));
}
}
return new_s;
}
rewrite_rule_sets join(rewrite_rule_sets const & s1, rewrite_rule_sets const & s2) {
rewrite_rule_sets new_s1 = s1;
s2.for_each([&](name const & eqv, rewrite_rule const & r) {
new_s1.insert(eqv, r);
});
return new_s1;
}
static name * g_class_name = nullptr;
static std::string * g_key = nullptr;
struct rrs_state {
rewrite_rule_sets m_sets;
name_map<rewrite_rule_sets> m_namespace_cache;
void add(environment const & env, name const & cname) {
// TODO(Leo): universe variables
// TODO(Leo): invalide cache for current namespace
type_checker tc(env);
declaration const & d = env.get(cname);
expr e = d.get_type();
expr h = mk_constant(cname);
m_sets = ::lean::add(tc, m_sets, cname, e, h);
}
};
struct rrs_config {
typedef name entry;
typedef rrs_state state;
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
s.add(env, e);
}
static name const & get_class_name() {
return *g_class_name;
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << e;
}
static entry read_entry(deserializer & d) {
entry e; d >> e; return e;
}
static optional<unsigned> get_fingerprint(entry const & e) {
return some(e.hash());
}
};
template class scoped_ext<rrs_config>;
typedef scoped_ext<rrs_config> rrs_ext;
environment add_rewrite_rule(environment const & env, name const & n, bool persistent) {
return rrs_ext::add_entry(env, get_dummy_ios(), n, persistent);
}
rewrite_rule_sets get_rewrite_rule_sets(environment const & env) {
return rrs_ext::get_state(env).m_sets;
}
void initialize_rewrite_rule_set() {
g_prefix = new name(name::mk_internal_unique_name());
g_class_name = new name("rrs");
g_key = new std::string("rrs");
rrs_ext::initialize();
}
void finalize_rewrite_rule_set() {
rrs_ext::finalize();
delete g_key;
delete g_class_name;
delete g_prefix;
}
}

View file

@ -0,0 +1,78 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "kernel/type_checker.h"
#include "library/head_map.h"
namespace lean {
class rewrite_rule_sets;
class rewrite_rule {
name m_id;
list<expr> m_metas;
expr m_lhs;
expr m_rhs;
constraint_seq m_cs;
expr m_proof;
bool m_is_permutation;
rewrite_rule(name const & id, list<expr> const & metas,
expr const & lhs, expr const & rhs, constraint_seq const & cs,
expr const & proof, bool is_perm);
friend rewrite_rule_sets add(type_checker & tc, rewrite_rule_sets const & s, name const & id,
expr const & e, expr const & h);
public:
name const & get_id() const { return m_id; }
list<expr> const & get_metas() const { return m_metas; }
expr const & get_lhs() const { return m_lhs; }
expr const & get_rhs() const { return m_rhs; }
constraint_seq const & get_cs() const { return m_cs; }
expr const & get_proof() const { return m_proof; }
bool is_perm() const { return m_is_permutation; }
friend bool operator==(rewrite_rule const & r1, rewrite_rule const & r2);
};
bool operator==(rewrite_rule const & r1, rewrite_rule const & r2);
inline bool operator!=(rewrite_rule const & r1, rewrite_rule const & r2) { return !operator==(r1, r2); }
class rewrite_rule_set {
typedef head_map<rewrite_rule> rule_set;
name m_eqv;
rule_set m_set;
public:
rewrite_rule_set() {}
/** \brief Return the equivalence relation associated with this set */
rewrite_rule_set(name const & eqv);
bool empty() const { return m_set.empty(); }
name const & get_eqv() const { return m_eqv; }
void insert(rewrite_rule const & r);
void erase(rewrite_rule const & r);
list<rewrite_rule> const * find(head_index const & h) const;
void for_each(std::function<void(rewrite_rule const &)> const & fn) const;
};
class rewrite_rule_sets {
name_map<rewrite_rule_set> m_sets; // mapping from relation name to rewrite_rule_set
public:
void insert(name const & eqv, rewrite_rule const & r);
void erase(name const & eqv, rewrite_rule const & r);
void get_relations(buffer<name> & rs) const;
rewrite_rule_set const * find(name const & eqv) const;
list<rewrite_rule> const * find(name const & eqv, head_index const & h) const;
void for_each(std::function<void(name const &, rewrite_rule const &)> const & fn) const;
};
rewrite_rule_sets add(type_checker & tc, rewrite_rule_sets const & s, name const & id, expr const & e, expr const & h);
rewrite_rule_sets join(rewrite_rule_sets const & s1, rewrite_rule_sets const & s2);
environment add_rewrite_rule(environment const & env, name const & n, bool persistent = true);
/** \brief Get current rewrite rule sets */
rewrite_rule_sets get_rewrite_rule_sets(environment const & env);
/** \brief Get rewrite rule sets in the given namespace. */
rewrite_rule_sets get_rewrite_rule_set(environment const & env, name const & ns);
void initialize_rewrite_rule_set();
void finalize_rewrite_rule_set();
}

View file

@ -28,9 +28,9 @@ static void tst1() {
lean_assert(map.contains(a));
map.insert(a, b);
lean_assert(map.contains(a));
map.erase_entry(a, b);
map.erase(a, b);
lean_assert(map.contains(a));
map.erase_entry(a, a);
map.erase(a, a);
lean_assert(!map.contains(a));
lean_assert(map.empty());
map.insert(l1, a);