refactor(library/simplifier): cleanup rewrite_rule_set, and use it in the simplifier

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-18 20:52:33 -08:00
parent 466285c577
commit 32c5bc25e3
4 changed files with 227 additions and 58 deletions

View file

@ -15,15 +15,10 @@ Author: Leonardo de Moura
#include "library/simplifier/rewrite_rule_set.h" #include "library/simplifier/rewrite_rule_set.h"
namespace lean { namespace lean {
struct rewrite_rule_set::rewrite_rule { rewrite_rule::rewrite_rule(name const & id, expr const & lhs, expr const & rhs, expr const & ceq, expr const & proof,
name m_id; unsigned num_args, bool is_permutation):
expr m_lhs; m_id(id), m_lhs(lhs), m_rhs(rhs), m_ceq(ceq), m_proof(proof), m_num_args(num_args), m_is_permutation(is_permutation) {
expr m_ceq; }
expr m_proof;
bool m_is_permutation;
rewrite_rule(name const & id, expr const & lhs, expr const & ceq, expr const & proof, bool is_permutation):
m_id(id), m_lhs(lhs), m_ceq(ceq), m_proof(proof), m_is_permutation(is_permutation) {}
};
rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_env(env.to_weak_ref()) {} rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_env(env.to_weak_ref()) {}
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other): rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other):
@ -36,13 +31,16 @@ void rewrite_rule_set::insert(name const & id, expr const & th, expr const & pro
expr const & ceq = p.first; expr const & ceq = p.first;
expr const & proof = p.second; expr const & proof = p.second;
bool is_perm = is_permutation_ceq(ceq); bool is_perm = is_permutation_ceq(ceq);
expr lhs = ceq; expr eq = ceq;
while (is_pi(lhs)) { unsigned num = 0;
lhs = abst_body(lhs); while (is_pi(eq)) {
eq = abst_body(eq);
num++;
} }
lean_assert(is_equality(lhs)); lean_assert(is_equality(eq));
lhs = arg(lhs, num_args(lhs) - 2); m_rule_set = cons(rewrite_rule(id, arg(eq, num_args(eq) - 2), arg(eq, num_args(eq) - 1),
m_rule_set.emplace_front(id, lhs, ceq, proof, is_perm); ceq, proof, num, is_perm),
m_rule_set);
} }
} }
@ -57,7 +55,7 @@ void rewrite_rule_set::insert(name const & th_name) {
} }
bool rewrite_rule_set::enabled(rewrite_rule const & rule) const { bool rewrite_rule_set::enabled(rewrite_rule const & rule) const {
return !m_disabled_rules.contains(rule.m_id); return !m_disabled_rules.contains(rule.get_id());
} }
bool rewrite_rule_set::enabled(name const & id) const { bool rewrite_rule_set::enabled(name const & id) const {
@ -71,18 +69,19 @@ void rewrite_rule_set::enable(name const & id, bool f) {
m_disabled_rules.insert(id); m_disabled_rules.insert(id);
} }
void rewrite_rule_set::for_each_match_candidate(expr const &, match_fn const & fn) const { bool rewrite_rule_set::find_match(expr const &, match_fn const & fn) const {
auto l = m_rule_set; auto l = m_rule_set;
for (auto const & rule : l) { for (auto const & rule : l) {
if (enabled(rule) && fn(rule.m_lhs, rule.m_ceq, rule.m_is_permutation, rule.m_proof)) if (enabled(rule) && fn(rule))
return; return true;
} }
return false;
} }
void rewrite_rule_set::for_each(visit_fn const & fn) const { void rewrite_rule_set::for_each(visit_fn const & fn) const {
auto l = m_rule_set; auto l = m_rule_set;
for (auto const & rule : l) { for (auto const & rule : l) {
fn(rule.m_id, rule.m_ceq, rule.m_proof, enabled(rule)); fn(rule, enabled(rule));
} }
} }
@ -90,16 +89,16 @@ format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const {
format r; format r;
bool first = true; bool first = true;
unsigned indent = get_pp_indent(opts); unsigned indent = get_pp_indent(opts);
for_each([&](name const & name, expr const & ceq, expr const &, bool enabled) { for_each([&](rewrite_rule const & rule, bool enabled) {
if (first) if (first)
first = false; first = false;
else else
r += line(); r += line();
r += format(name); r += format(rule.get_id());
if (!enabled) if (!enabled)
r += format(" [disabled]"); r += format(" [disabled]");
r += format{space(), colon(), space()}; r += format{space(), colon(), space()};
r += nest(indent, fmt(ceq, opts)); r += nest(indent, fmt(rule.get_ceq(), opts));
}); });
return r; return r;
} }

View file

@ -15,11 +15,32 @@ Author: Leonardo de Moura
#include "kernel/formatter.h" #include "kernel/formatter.h"
namespace lean { namespace lean {
class rewrite_rule_set;
class rewrite_rule {
friend class rewrite_rule_set;
name m_id;
expr m_lhs;
expr m_rhs;
expr m_ceq;
expr m_proof;
unsigned m_num_args;
bool m_is_permutation;
rewrite_rule(name const & id, expr const & lhs, expr const & rhs, expr const & ceq, expr const & proof,
unsigned num_args, bool is_permutation);
public:
name const & get_id() const { return m_id; }
expr const & get_lhs() const { return m_lhs; }
expr const & get_rhs() const { return m_rhs; }
expr const & get_ceq() const { return m_ceq; }
expr const & get_proof() const { return m_proof; }
unsigned get_num_args() const { return m_num_args; }
bool is_permutation() const { return m_is_permutation; }
};
/** /**
\brief Actual implementation of the \c rewrite_rule_set class. \brief Actual implementation of the \c rewrite_rule_set class.
*/ */
class rewrite_rule_set { class rewrite_rule_set {
struct rewrite_rule;
typedef splay_tree<name, name_quick_cmp> name_set; typedef splay_tree<name, name_quick_cmp> name_set;
ro_environment::weak_ref m_env; ro_environment::weak_ref m_env;
list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees
@ -51,22 +72,17 @@ public:
/** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */ /** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */
void enable(name const & id, bool f); void enable(name const & id, bool f);
typedef std::function<bool(expr const &, expr const &, bool is_permutation, expr const &)> match_fn; // NOLINT typedef std::function<bool(rewrite_rule const &)> match_fn; // NOLINT
typedef std::function<void(name const &, expr const &, expr const &, bool)> visit_fn; typedef std::function<void(rewrite_rule const &, bool)> visit_fn;
/** /**
\brief Execute <tt>fn(lhs, ceq, is_perm, proof)</tt> for each (enabled) rule whose the left-hand-side may \brief Execute <tt>fn(rule)</tt> for each (enabled) rule whose the left-hand-side may
match \c e. match \c e.
The traversal is interrupted as soon as \c fn returns true. The traversal is interrupted as soon as \c fn returns true.
The redundant argument \c lhs is the left-hand-side of \c ceq.
The redundant argument \c is_perm is true iff \c ceq is a permutation rule.
The argument \c proof is the proof for \c ceq.
*/ */
void for_each_match_candidate(expr const &, match_fn const & fn) const; bool find_match(expr const &, match_fn const & fn) const;
/** \brief Execute <tt>fn(id, ceq, proof, enabled)</tt> for each rule in this rule set. */ /** \brief Execute <tt>fn(rule, enabled)</tt> for each rule in this rule set. */
void for_each(visit_fn const & fn) const; void for_each(visit_fn const & fn) const;
/** \brief Pretty print this rule set. */ /** \brief Pretty print this rule set. */

View file

@ -92,6 +92,16 @@ class simplifier_fn {
bool m_unfold; bool m_unfold;
unsigned m_max_steps; unsigned m_max_steps;
struct match_fn {
simplifier_fn & m_simp;
match_fn(simplifier_fn & s):m_simp(s) {}
bool operator()(rewrite_rule const & rule) const {
return m_simp.match(rule);
}
};
match_fn m_match_fn;
struct result { struct result {
expr m_out; expr m_out;
optional<expr> m_proof; optional<expr> m_proof;
@ -126,20 +136,20 @@ class simplifier_fn {
expr mk_congr1_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & Heq_f) { expr mk_congr1_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & Heq_f) {
expr const & A = abst_domain(f_type); expr const & A = abst_domain(f_type);
expr const & B = lower_free_vars(abst_body(f_type), 1, 1); expr B = lower_free_vars(abst_body(f_type), 1, 1);
return ::lean::mk_congr1_th(A, B, f, new_f, a, Heq_f); return ::lean::mk_congr1_th(A, B, f, new_f, a, Heq_f);
} }
expr mk_congr2_th(expr const & f_type, expr const & a, expr const & new_a, expr const & f, expr const & Heq_a) { expr mk_congr2_th(expr const & f_type, expr const & a, expr const & new_a, expr const & f, expr const & Heq_a) {
expr const & A = abst_domain(f_type); expr const & A = abst_domain(f_type);
expr const & B = lower_free_vars(abst_body(f_type), 1, 1); expr B = lower_free_vars(abst_body(f_type), 1, 1);
return ::lean::mk_congr2_th(A, B, a, new_a, f, Heq_a); return ::lean::mk_congr2_th(A, B, a, new_a, f, Heq_a);
} }
expr mk_congr_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & new_a, expr mk_congr_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & new_a,
expr const & Heq_f, expr const & Heq_a) { expr const & Heq_f, expr const & Heq_a) {
expr const & A = abst_domain(f_type); expr const & A = abst_domain(f_type);
expr const & B = lower_free_vars(abst_body(f_type), 1, 1); expr B = lower_free_vars(abst_body(f_type), 1, 1);
return ::lean::mk_congr_th(A, B, f, new_f, a, new_a, Heq_f, Heq_a); return ::lean::mk_congr_th(A, B, f, new_f, a, new_a, Heq_f, Heq_a);
} }
@ -152,6 +162,65 @@ class simplifier_fn {
f, new_f, a, new_a, Heq_f, Heq_a); f, new_f, a, new_a, Heq_f, Heq_a);
} }
result mk_trans_result(expr const & a, result const & b_res, expr const & c, expr const & H_bc) {
if (m_proofs_enabled) {
if (!b_res.m_proof) {
// The proof of a = b is reflexivity
return result(c, H_bc);
} else {
expr const & b = b_res.m_out;
expr new_proof;
bool heq_proof = false;
if (b_res.m_heq_proof) {
expr b_type = infer_type(b);
new_proof = ::lean::mk_htrans_th(infer_type(a), b_type, b_type, /* b and c must have the same type */
a, b, c, *b_res.m_proof, mk_to_heq_th(b_type, b, c, H_bc));
heq_proof = true;
} else {
new_proof = ::lean::mk_trans_th(infer_type(a), a, b, c, *b_res.m_proof, H_bc);
}
return result(c, new_proof, heq_proof);
}
} else {
return result(c);
}
}
result mk_trans_result(expr const & a, result const & b_res, result const & c_res) {
if (m_proofs_enabled) {
if (!b_res.m_proof) {
// the proof of a == b is reflexivity
return c_res;
} else if (!c_res.m_proof) {
// the proof of b == c is reflexivity
return result(c_res.m_out, *b_res.m_proof, b_res.m_heq_proof);
} else {
bool heq_proof = b_res.m_heq_proof || c_res.m_heq_proof;
expr new_proof;
expr const & b = b_res.m_out;
expr const & c = c_res.m_out;
if (heq_proof) {
expr a_type = infer_type(a);
expr b_type = infer_type(b);
expr c_type = infer_type(c);
expr H_ab = *b_res.m_proof;
if (!b_res.m_heq_proof)
H_ab = mk_to_heq_th(a_type, a, b, H_ab);
expr H_bc = *c_res.m_proof;
if (!c_res.m_heq_proof)
H_bc = mk_to_heq_th(b_type, b, c, H_bc);
new_proof = ::lean::mk_htrans_th(a_type, b_type, c_type, a, b, c, H_ab, H_bc);
} else {
new_proof = ::lean::mk_trans_th(infer_type(a), a, b, c, *b_res.m_proof, *c_res.m_proof);
}
return result(c, new_proof, heq_proof);
}
} else {
// proof generation is disabled
return c_res;
}
}
expr mk_app_prefix(unsigned i, expr const & a) { expr mk_app_prefix(unsigned i, expr const & a) {
lean_assert(i > 0); lean_assert(i > 0);
if (i == 1) if (i == 1)
@ -225,9 +294,9 @@ class simplifier_fn {
} }
if (!changed) { if (!changed) {
return rewrite_app(result(e)); return rewrite(e, result(e));
} else if (!m_proofs_enabled) { } else if (!m_proofs_enabled) {
return rewrite_app(result(mk_app(new_args))); return rewrite(e, result(mk_app(new_args)));
} else { } else {
expr out = mk_app(new_args); expr out = mk_app(new_args);
unsigned i = 0; unsigned i = 0;
@ -272,11 +341,76 @@ class simplifier_fn {
pr = mk_congr1_th(f_types[i-1], f, new_f, arg(e, i), pr); pr = mk_congr1_th(f_types[i-1], f, new_f, arg(e, i), pr);
} }
} }
return rewrite_app(result(out, pr, heq_proof)); return rewrite(e, result(out, pr, heq_proof));
} }
} }
result rewrite_app(result const & r) { expr m_target; // temp field
buffer<optional<expr>> m_subst; // temp field
buffer<expr> m_new_args; // temp field
expr m_new_rhs; // temp field
expr m_new_proof; // temp field
void reset_subst(unsigned num_args) {
if (m_subst.size() < num_args) {
m_subst.resize(num_args);
m_new_args.resize(num_args+1);
}
for (unsigned i = 0; i < num_args; i++)
m_subst[i] = none_expr();
}
bool found_all_args(unsigned num_args) {
for (unsigned i = 0; i < num_args; i++) {
if (!m_subst[i])
return false;
m_new_args[i+1] = *m_subst[i];
}
return true;
}
/**
\brief Auxiliary function used by m_match_fn, it tries to match the given rule and
the expression in the temporary field \c m_target.
If it succeeds, then the resultant expression is stored in the temporary field m_new_rhs,
and the proof in m_new_proof (if proofs are enabled).
*/
bool match(rewrite_rule const & rule) {
unsigned num = rule.get_num_args();
reset_subst(num);
if (hop_match(rule.get_lhs(), m_target, m_subst)) {
if (found_all_args(num)) {
// easy case: all arguments found
m_new_rhs = instantiate(rule.get_rhs(), num, m_new_args.data() + 1);
if (m_proofs_enabled) {
if (num > 0) {
m_new_args[0] = rule.get_proof();
m_new_proof = mk_app(m_new_args);
} else {
m_new_proof = rule.get_proof();
}
}
return true;
}
// TODO(Leo): conditional rewriting
}
return false;
}
result rewrite(expr const & e, result const & r) {
m_target = r.m_out;
for (rewrite_rule_set const & rs : m_rule_sets) {
if (rs.find_match(m_target, m_match_fn)) {
// the result is in m_new_rhs and proof at m_new_proof
result new_r1 = mk_trans_result(e, r, m_new_rhs, m_new_proof);
if (m_single_pass) {
return new_r1;
} else {
result new_r2 = simplify(new_r1.m_out);
return mk_trans_result(e, new_r1, new_r2);
}
}
}
return r; return r;
} }
@ -302,17 +436,7 @@ class simplifier_fn {
} }
} }
} }
#if 1 return rewrite(e, result(e));
if (const_name(e) == "a") {
auto obj = m_env->find_object("a_eq_0");
if (obj) {
expr r = arg(obj->get_type(), 3);
return result(r, mk_constant("a_eq_0"));
}
}
#endif
return result(e);
} }
result simplify_lambda(expr const & e) { result simplify_lambda(expr const & e) {
@ -328,11 +452,11 @@ class simplifier_fn {
if (is_eqp(new_body, abst_body(e))) if (is_eqp(new_body, abst_body(e)))
return result(e); return result(e);
expr out = mk_lambda(e, new_body); expr out = mk_lambda(e, new_body);
if (!m_proofs_enabled) if (!m_proofs_enabled || !res_body.m_proof)
return result(out); return result(out);
expr body_type = infer_type(abst_body(e)); expr body_type = infer_type(abst_body(e));
expr pr = mk_funext_th(abst_domain(e), mk_lambda(e, body_type), e, out, expr pr = mk_funext_th(abst_domain(e), mk_lambda(e, body_type), e, out,
mk_lambda(e, *(res_body.m_proof))); mk_lambda(e, *res_body.m_proof));
return result(out, pr); return result(out, pr);
} }
} }
@ -350,12 +474,12 @@ class simplifier_fn {
if (is_eqp(new_body, abst_body(e))) if (is_eqp(new_body, abst_body(e)))
return result(e); return result(e);
expr out = mk_pi(abst_name(e), abst_domain(e), new_body); expr out = mk_pi(abst_name(e), abst_domain(e), new_body);
if (!m_proofs_enabled) if (!m_proofs_enabled || !res_body.m_proof)
return result(out); return result(out);
expr pr = mk_allext_th(abst_domain(e), expr pr = mk_allext_th(abst_domain(e),
mk_lambda(e, abst_body(e)), mk_lambda(e, abst_body(e)),
mk_lambda(e, abst_body(out)), mk_lambda(e, abst_body(out)),
mk_lambda(e, *(res_body.m_proof))); mk_lambda(e, *res_body.m_proof));
return result(out, pr); return result(out, pr);
} else { } else {
// if the environment does not contain heq axioms, then we don't simplify Pi's that are not forall's // if the environment does not contain heq axioms, then we don't simplify Pi's that are not forall's
@ -384,13 +508,13 @@ class simplifier_fn {
m_contextual = get_simplifier_contextual(o); m_contextual = get_simplifier_contextual(o);
m_single_pass = get_simplifier_single_pass(o); m_single_pass = get_simplifier_single_pass(o);
m_beta = get_simplifier_beta(o); m_beta = get_simplifier_beta(o);
m_unfold = true; // get_simplifier_unfold(o); m_unfold = get_simplifier_unfold(o);
m_max_steps = get_simplifier_max_steps(o); m_max_steps = get_simplifier_max_steps(o);
} }
public: public:
simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs): simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs):
m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs) { m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs), m_match_fn(*this) {
m_has_heq = m_env->imported("heq"); m_has_heq = m_env->imported("heq");
set_options(o); set_options(o);
} }

30
tests/lua/simp1.lua Normal file
View file

@ -0,0 +1,30 @@
mk_rewrite_rule_set()
add_rewrite_rules({"Nat", "add_zerol"})
add_rewrite_rules({"Nat", "add_zeror"})
parse_lean_cmds([[
variable f : Nat -> Nat -> Nat
variable g : Nat -> Nat
variable b : Nat
definition a := 1
theorem a_eq_1 : a = 1
:= refl a
definition c := 1
set_opaque a true
axiom f_id (x : Nat) : f x 1 = 2*x
]])
add_rewrite_rules("a_eq_1")
add_rewrite_rules("f_id")
-- set_option({"lean", "pp", "implicit"}, true)
e, pr = simplify(parse_lean('fun x, f (f x (0 + a)) (g (b + 0))'))
print(e)
print(pr)
local env = get_environment()
print(env:type_check(pr))
e, pr = simplify(parse_lean('forall x, let d := a + 1 in f x a >= d'))
print(e)
print(pr)
local env = get_environment()
print(env:type_check(pr))