refactor(library/simplifier): cleanup rewrite_rule_set, and use it in the simplifier

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-18 20:52:33 -08:00
parent 466285c577
commit 32c5bc25e3
4 changed files with 227 additions and 58 deletions

View file

@ -15,15 +15,10 @@ Author: Leonardo de Moura
#include "library/simplifier/rewrite_rule_set.h"
namespace lean {
struct rewrite_rule_set::rewrite_rule {
name m_id;
expr m_lhs;
expr m_ceq;
expr m_proof;
bool m_is_permutation;
rewrite_rule(name const & id, expr const & lhs, expr const & ceq, expr const & proof, bool is_permutation):
m_id(id), m_lhs(lhs), m_ceq(ceq), m_proof(proof), m_is_permutation(is_permutation) {}
};
rewrite_rule::rewrite_rule(name const & id, expr const & lhs, expr const & rhs, expr const & ceq, expr const & proof,
unsigned num_args, bool is_permutation):
m_id(id), m_lhs(lhs), m_rhs(rhs), m_ceq(ceq), m_proof(proof), m_num_args(num_args), m_is_permutation(is_permutation) {
}
rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_env(env.to_weak_ref()) {}
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other):
@ -36,13 +31,16 @@ void rewrite_rule_set::insert(name const & id, expr const & th, expr const & pro
expr const & ceq = p.first;
expr const & proof = p.second;
bool is_perm = is_permutation_ceq(ceq);
expr lhs = ceq;
while (is_pi(lhs)) {
lhs = abst_body(lhs);
expr eq = ceq;
unsigned num = 0;
while (is_pi(eq)) {
eq = abst_body(eq);
num++;
}
lean_assert(is_equality(lhs));
lhs = arg(lhs, num_args(lhs) - 2);
m_rule_set.emplace_front(id, lhs, ceq, proof, is_perm);
lean_assert(is_equality(eq));
m_rule_set = cons(rewrite_rule(id, arg(eq, num_args(eq) - 2), arg(eq, num_args(eq) - 1),
ceq, proof, num, is_perm),
m_rule_set);
}
}
@ -57,7 +55,7 @@ void rewrite_rule_set::insert(name const & th_name) {
}
bool rewrite_rule_set::enabled(rewrite_rule const & rule) const {
return !m_disabled_rules.contains(rule.m_id);
return !m_disabled_rules.contains(rule.get_id());
}
bool rewrite_rule_set::enabled(name const & id) const {
@ -71,18 +69,19 @@ void rewrite_rule_set::enable(name const & id, bool f) {
m_disabled_rules.insert(id);
}
void rewrite_rule_set::for_each_match_candidate(expr const &, match_fn const & fn) const {
bool rewrite_rule_set::find_match(expr const &, match_fn const & fn) const {
auto l = m_rule_set;
for (auto const & rule : l) {
if (enabled(rule) && fn(rule.m_lhs, rule.m_ceq, rule.m_is_permutation, rule.m_proof))
return;
if (enabled(rule) && fn(rule))
return true;
}
return false;
}
void rewrite_rule_set::for_each(visit_fn const & fn) const {
auto l = m_rule_set;
for (auto const & rule : l) {
fn(rule.m_id, rule.m_ceq, rule.m_proof, enabled(rule));
fn(rule, enabled(rule));
}
}
@ -90,16 +89,16 @@ format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const {
format r;
bool first = true;
unsigned indent = get_pp_indent(opts);
for_each([&](name const & name, expr const & ceq, expr const &, bool enabled) {
for_each([&](rewrite_rule const & rule, bool enabled) {
if (first)
first = false;
else
r += line();
r += format(name);
r += format(rule.get_id());
if (!enabled)
r += format(" [disabled]");
r += format{space(), colon(), space()};
r += nest(indent, fmt(ceq, opts));
r += nest(indent, fmt(rule.get_ceq(), opts));
});
return r;
}

View file

@ -15,11 +15,32 @@ Author: Leonardo de Moura
#include "kernel/formatter.h"
namespace lean {
class rewrite_rule_set;
class rewrite_rule {
friend class rewrite_rule_set;
name m_id;
expr m_lhs;
expr m_rhs;
expr m_ceq;
expr m_proof;
unsigned m_num_args;
bool m_is_permutation;
rewrite_rule(name const & id, expr const & lhs, expr const & rhs, expr const & ceq, expr const & proof,
unsigned num_args, bool is_permutation);
public:
name const & get_id() const { return m_id; }
expr const & get_lhs() const { return m_lhs; }
expr const & get_rhs() const { return m_rhs; }
expr const & get_ceq() const { return m_ceq; }
expr const & get_proof() const { return m_proof; }
unsigned get_num_args() const { return m_num_args; }
bool is_permutation() const { return m_is_permutation; }
};
/**
\brief Actual implementation of the \c rewrite_rule_set class.
*/
class rewrite_rule_set {
struct rewrite_rule;
typedef splay_tree<name, name_quick_cmp> name_set;
ro_environment::weak_ref m_env;
list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees
@ -51,22 +72,17 @@ public:
/** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */
void enable(name const & id, bool f);
typedef std::function<bool(expr const &, expr const &, bool is_permutation, expr const &)> match_fn; // NOLINT
typedef std::function<void(name const &, expr const &, expr const &, bool)> visit_fn;
typedef std::function<bool(rewrite_rule const &)> match_fn; // NOLINT
typedef std::function<void(rewrite_rule const &, bool)> visit_fn;
/**
\brief Execute <tt>fn(lhs, ceq, is_perm, proof)</tt> for each (enabled) rule whose the left-hand-side may
\brief Execute <tt>fn(rule)</tt> for each (enabled) rule whose the left-hand-side may
match \c e.
The traversal is interrupted as soon as \c fn returns true.
The redundant argument \c lhs is the left-hand-side of \c ceq.
The redundant argument \c is_perm is true iff \c ceq is a permutation rule.
The argument \c proof is the proof for \c ceq.
*/
void for_each_match_candidate(expr const &, match_fn const & fn) const;
bool find_match(expr const &, match_fn const & fn) const;
/** \brief Execute <tt>fn(id, ceq, proof, enabled)</tt> for each rule in this rule set. */
/** \brief Execute <tt>fn(rule, enabled)</tt> for each rule in this rule set. */
void for_each(visit_fn const & fn) const;
/** \brief Pretty print this rule set. */

View file

@ -92,6 +92,16 @@ class simplifier_fn {
bool m_unfold;
unsigned m_max_steps;
struct match_fn {
simplifier_fn & m_simp;
match_fn(simplifier_fn & s):m_simp(s) {}
bool operator()(rewrite_rule const & rule) const {
return m_simp.match(rule);
}
};
match_fn m_match_fn;
struct result {
expr m_out;
optional<expr> m_proof;
@ -126,20 +136,20 @@ class simplifier_fn {
expr mk_congr1_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & Heq_f) {
expr const & A = abst_domain(f_type);
expr const & B = lower_free_vars(abst_body(f_type), 1, 1);
expr B = lower_free_vars(abst_body(f_type), 1, 1);
return ::lean::mk_congr1_th(A, B, f, new_f, a, Heq_f);
}
expr mk_congr2_th(expr const & f_type, expr const & a, expr const & new_a, expr const & f, expr const & Heq_a) {
expr const & A = abst_domain(f_type);
expr const & B = lower_free_vars(abst_body(f_type), 1, 1);
expr B = lower_free_vars(abst_body(f_type), 1, 1);
return ::lean::mk_congr2_th(A, B, a, new_a, f, Heq_a);
}
expr mk_congr_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & new_a,
expr const & Heq_f, expr const & Heq_a) {
expr const & A = abst_domain(f_type);
expr const & B = lower_free_vars(abst_body(f_type), 1, 1);
expr B = lower_free_vars(abst_body(f_type), 1, 1);
return ::lean::mk_congr_th(A, B, f, new_f, a, new_a, Heq_f, Heq_a);
}
@ -152,6 +162,65 @@ class simplifier_fn {
f, new_f, a, new_a, Heq_f, Heq_a);
}
result mk_trans_result(expr const & a, result const & b_res, expr const & c, expr const & H_bc) {
if (m_proofs_enabled) {
if (!b_res.m_proof) {
// The proof of a = b is reflexivity
return result(c, H_bc);
} else {
expr const & b = b_res.m_out;
expr new_proof;
bool heq_proof = false;
if (b_res.m_heq_proof) {
expr b_type = infer_type(b);
new_proof = ::lean::mk_htrans_th(infer_type(a), b_type, b_type, /* b and c must have the same type */
a, b, c, *b_res.m_proof, mk_to_heq_th(b_type, b, c, H_bc));
heq_proof = true;
} else {
new_proof = ::lean::mk_trans_th(infer_type(a), a, b, c, *b_res.m_proof, H_bc);
}
return result(c, new_proof, heq_proof);
}
} else {
return result(c);
}
}
result mk_trans_result(expr const & a, result const & b_res, result const & c_res) {
if (m_proofs_enabled) {
if (!b_res.m_proof) {
// the proof of a == b is reflexivity
return c_res;
} else if (!c_res.m_proof) {
// the proof of b == c is reflexivity
return result(c_res.m_out, *b_res.m_proof, b_res.m_heq_proof);
} else {
bool heq_proof = b_res.m_heq_proof || c_res.m_heq_proof;
expr new_proof;
expr const & b = b_res.m_out;
expr const & c = c_res.m_out;
if (heq_proof) {
expr a_type = infer_type(a);
expr b_type = infer_type(b);
expr c_type = infer_type(c);
expr H_ab = *b_res.m_proof;
if (!b_res.m_heq_proof)
H_ab = mk_to_heq_th(a_type, a, b, H_ab);
expr H_bc = *c_res.m_proof;
if (!c_res.m_heq_proof)
H_bc = mk_to_heq_th(b_type, b, c, H_bc);
new_proof = ::lean::mk_htrans_th(a_type, b_type, c_type, a, b, c, H_ab, H_bc);
} else {
new_proof = ::lean::mk_trans_th(infer_type(a), a, b, c, *b_res.m_proof, *c_res.m_proof);
}
return result(c, new_proof, heq_proof);
}
} else {
// proof generation is disabled
return c_res;
}
}
expr mk_app_prefix(unsigned i, expr const & a) {
lean_assert(i > 0);
if (i == 1)
@ -225,9 +294,9 @@ class simplifier_fn {
}
if (!changed) {
return rewrite_app(result(e));
return rewrite(e, result(e));
} else if (!m_proofs_enabled) {
return rewrite_app(result(mk_app(new_args)));
return rewrite(e, result(mk_app(new_args)));
} else {
expr out = mk_app(new_args);
unsigned i = 0;
@ -272,11 +341,76 @@ class simplifier_fn {
pr = mk_congr1_th(f_types[i-1], f, new_f, arg(e, i), pr);
}
}
return rewrite_app(result(out, pr, heq_proof));
return rewrite(e, result(out, pr, heq_proof));
}
}
result rewrite_app(result const & r) {
expr m_target; // temp field
buffer<optional<expr>> m_subst; // temp field
buffer<expr> m_new_args; // temp field
expr m_new_rhs; // temp field
expr m_new_proof; // temp field
void reset_subst(unsigned num_args) {
if (m_subst.size() < num_args) {
m_subst.resize(num_args);
m_new_args.resize(num_args+1);
}
for (unsigned i = 0; i < num_args; i++)
m_subst[i] = none_expr();
}
bool found_all_args(unsigned num_args) {
for (unsigned i = 0; i < num_args; i++) {
if (!m_subst[i])
return false;
m_new_args[i+1] = *m_subst[i];
}
return true;
}
/**
\brief Auxiliary function used by m_match_fn, it tries to match the given rule and
the expression in the temporary field \c m_target.
If it succeeds, then the resultant expression is stored in the temporary field m_new_rhs,
and the proof in m_new_proof (if proofs are enabled).
*/
bool match(rewrite_rule const & rule) {
unsigned num = rule.get_num_args();
reset_subst(num);
if (hop_match(rule.get_lhs(), m_target, m_subst)) {
if (found_all_args(num)) {
// easy case: all arguments found
m_new_rhs = instantiate(rule.get_rhs(), num, m_new_args.data() + 1);
if (m_proofs_enabled) {
if (num > 0) {
m_new_args[0] = rule.get_proof();
m_new_proof = mk_app(m_new_args);
} else {
m_new_proof = rule.get_proof();
}
}
return true;
}
// TODO(Leo): conditional rewriting
}
return false;
}
result rewrite(expr const & e, result const & r) {
m_target = r.m_out;
for (rewrite_rule_set const & rs : m_rule_sets) {
if (rs.find_match(m_target, m_match_fn)) {
// the result is in m_new_rhs and proof at m_new_proof
result new_r1 = mk_trans_result(e, r, m_new_rhs, m_new_proof);
if (m_single_pass) {
return new_r1;
} else {
result new_r2 = simplify(new_r1.m_out);
return mk_trans_result(e, new_r1, new_r2);
}
}
}
return r;
}
@ -302,17 +436,7 @@ class simplifier_fn {
}
}
}
#if 1
if (const_name(e) == "a") {
auto obj = m_env->find_object("a_eq_0");
if (obj) {
expr r = arg(obj->get_type(), 3);
return result(r, mk_constant("a_eq_0"));
}
}
#endif
return result(e);
return rewrite(e, result(e));
}
result simplify_lambda(expr const & e) {
@ -328,11 +452,11 @@ class simplifier_fn {
if (is_eqp(new_body, abst_body(e)))
return result(e);
expr out = mk_lambda(e, new_body);
if (!m_proofs_enabled)
if (!m_proofs_enabled || !res_body.m_proof)
return result(out);
expr body_type = infer_type(abst_body(e));
expr pr = mk_funext_th(abst_domain(e), mk_lambda(e, body_type), e, out,
mk_lambda(e, *(res_body.m_proof)));
mk_lambda(e, *res_body.m_proof));
return result(out, pr);
}
}
@ -350,12 +474,12 @@ class simplifier_fn {
if (is_eqp(new_body, abst_body(e)))
return result(e);
expr out = mk_pi(abst_name(e), abst_domain(e), new_body);
if (!m_proofs_enabled)
if (!m_proofs_enabled || !res_body.m_proof)
return result(out);
expr pr = mk_allext_th(abst_domain(e),
mk_lambda(e, abst_body(e)),
mk_lambda(e, abst_body(out)),
mk_lambda(e, *(res_body.m_proof)));
mk_lambda(e, *res_body.m_proof));
return result(out, pr);
} else {
// if the environment does not contain heq axioms, then we don't simplify Pi's that are not forall's
@ -384,13 +508,13 @@ class simplifier_fn {
m_contextual = get_simplifier_contextual(o);
m_single_pass = get_simplifier_single_pass(o);
m_beta = get_simplifier_beta(o);
m_unfold = true; // get_simplifier_unfold(o);
m_unfold = get_simplifier_unfold(o);
m_max_steps = get_simplifier_max_steps(o);
}
public:
simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs):
m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs) {
m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs), m_match_fn(*this) {
m_has_heq = m_env->imported("heq");
set_options(o);
}

30
tests/lua/simp1.lua Normal file
View file

@ -0,0 +1,30 @@
mk_rewrite_rule_set()
add_rewrite_rules({"Nat", "add_zerol"})
add_rewrite_rules({"Nat", "add_zeror"})
parse_lean_cmds([[
variable f : Nat -> Nat -> Nat
variable g : Nat -> Nat
variable b : Nat
definition a := 1
theorem a_eq_1 : a = 1
:= refl a
definition c := 1
set_opaque a true
axiom f_id (x : Nat) : f x 1 = 2*x
]])
add_rewrite_rules("a_eq_1")
add_rewrite_rules("f_id")
-- set_option({"lean", "pp", "implicit"}, true)
e, pr = simplify(parse_lean('fun x, f (f x (0 + a)) (g (b + 0))'))
print(e)
print(pr)
local env = get_environment()
print(env:type_check(pr))
e, pr = simplify(parse_lean('forall x, let d := a + 1 in f x a >= d'))
print(e)
print(pr)
local env = get_environment()
print(env:type_check(pr))