refactor(library/simplifier): cleanup rewrite_rule_set, and use it in the simplifier
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
466285c577
commit
32c5bc25e3
4 changed files with 227 additions and 58 deletions
|
@ -15,15 +15,10 @@ Author: Leonardo de Moura
|
|||
#include "library/simplifier/rewrite_rule_set.h"
|
||||
|
||||
namespace lean {
|
||||
struct rewrite_rule_set::rewrite_rule {
|
||||
name m_id;
|
||||
expr m_lhs;
|
||||
expr m_ceq;
|
||||
expr m_proof;
|
||||
bool m_is_permutation;
|
||||
rewrite_rule(name const & id, expr const & lhs, expr const & ceq, expr const & proof, bool is_permutation):
|
||||
m_id(id), m_lhs(lhs), m_ceq(ceq), m_proof(proof), m_is_permutation(is_permutation) {}
|
||||
};
|
||||
rewrite_rule::rewrite_rule(name const & id, expr const & lhs, expr const & rhs, expr const & ceq, expr const & proof,
|
||||
unsigned num_args, bool is_permutation):
|
||||
m_id(id), m_lhs(lhs), m_rhs(rhs), m_ceq(ceq), m_proof(proof), m_num_args(num_args), m_is_permutation(is_permutation) {
|
||||
}
|
||||
|
||||
rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_env(env.to_weak_ref()) {}
|
||||
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other):
|
||||
|
@ -36,13 +31,16 @@ void rewrite_rule_set::insert(name const & id, expr const & th, expr const & pro
|
|||
expr const & ceq = p.first;
|
||||
expr const & proof = p.second;
|
||||
bool is_perm = is_permutation_ceq(ceq);
|
||||
expr lhs = ceq;
|
||||
while (is_pi(lhs)) {
|
||||
lhs = abst_body(lhs);
|
||||
expr eq = ceq;
|
||||
unsigned num = 0;
|
||||
while (is_pi(eq)) {
|
||||
eq = abst_body(eq);
|
||||
num++;
|
||||
}
|
||||
lean_assert(is_equality(lhs));
|
||||
lhs = arg(lhs, num_args(lhs) - 2);
|
||||
m_rule_set.emplace_front(id, lhs, ceq, proof, is_perm);
|
||||
lean_assert(is_equality(eq));
|
||||
m_rule_set = cons(rewrite_rule(id, arg(eq, num_args(eq) - 2), arg(eq, num_args(eq) - 1),
|
||||
ceq, proof, num, is_perm),
|
||||
m_rule_set);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +55,7 @@ void rewrite_rule_set::insert(name const & th_name) {
|
|||
}
|
||||
|
||||
bool rewrite_rule_set::enabled(rewrite_rule const & rule) const {
|
||||
return !m_disabled_rules.contains(rule.m_id);
|
||||
return !m_disabled_rules.contains(rule.get_id());
|
||||
}
|
||||
|
||||
bool rewrite_rule_set::enabled(name const & id) const {
|
||||
|
@ -71,18 +69,19 @@ void rewrite_rule_set::enable(name const & id, bool f) {
|
|||
m_disabled_rules.insert(id);
|
||||
}
|
||||
|
||||
void rewrite_rule_set::for_each_match_candidate(expr const &, match_fn const & fn) const {
|
||||
bool rewrite_rule_set::find_match(expr const &, match_fn const & fn) const {
|
||||
auto l = m_rule_set;
|
||||
for (auto const & rule : l) {
|
||||
if (enabled(rule) && fn(rule.m_lhs, rule.m_ceq, rule.m_is_permutation, rule.m_proof))
|
||||
return;
|
||||
if (enabled(rule) && fn(rule))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void rewrite_rule_set::for_each(visit_fn const & fn) const {
|
||||
auto l = m_rule_set;
|
||||
for (auto const & rule : l) {
|
||||
fn(rule.m_id, rule.m_ceq, rule.m_proof, enabled(rule));
|
||||
fn(rule, enabled(rule));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,16 +89,16 @@ format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const {
|
|||
format r;
|
||||
bool first = true;
|
||||
unsigned indent = get_pp_indent(opts);
|
||||
for_each([&](name const & name, expr const & ceq, expr const &, bool enabled) {
|
||||
for_each([&](rewrite_rule const & rule, bool enabled) {
|
||||
if (first)
|
||||
first = false;
|
||||
else
|
||||
r += line();
|
||||
r += format(name);
|
||||
r += format(rule.get_id());
|
||||
if (!enabled)
|
||||
r += format(" [disabled]");
|
||||
r += format{space(), colon(), space()};
|
||||
r += nest(indent, fmt(ceq, opts));
|
||||
r += nest(indent, fmt(rule.get_ceq(), opts));
|
||||
});
|
||||
return r;
|
||||
}
|
||||
|
|
|
@ -15,11 +15,32 @@ Author: Leonardo de Moura
|
|||
#include "kernel/formatter.h"
|
||||
|
||||
namespace lean {
|
||||
class rewrite_rule_set;
|
||||
class rewrite_rule {
|
||||
friend class rewrite_rule_set;
|
||||
name m_id;
|
||||
expr m_lhs;
|
||||
expr m_rhs;
|
||||
expr m_ceq;
|
||||
expr m_proof;
|
||||
unsigned m_num_args;
|
||||
bool m_is_permutation;
|
||||
rewrite_rule(name const & id, expr const & lhs, expr const & rhs, expr const & ceq, expr const & proof,
|
||||
unsigned num_args, bool is_permutation);
|
||||
public:
|
||||
name const & get_id() const { return m_id; }
|
||||
expr const & get_lhs() const { return m_lhs; }
|
||||
expr const & get_rhs() const { return m_rhs; }
|
||||
expr const & get_ceq() const { return m_ceq; }
|
||||
expr const & get_proof() const { return m_proof; }
|
||||
unsigned get_num_args() const { return m_num_args; }
|
||||
bool is_permutation() const { return m_is_permutation; }
|
||||
};
|
||||
|
||||
/**
|
||||
\brief Actual implementation of the \c rewrite_rule_set class.
|
||||
*/
|
||||
class rewrite_rule_set {
|
||||
struct rewrite_rule;
|
||||
typedef splay_tree<name, name_quick_cmp> name_set;
|
||||
ro_environment::weak_ref m_env;
|
||||
list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees
|
||||
|
@ -51,22 +72,17 @@ public:
|
|||
/** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */
|
||||
void enable(name const & id, bool f);
|
||||
|
||||
typedef std::function<bool(expr const &, expr const &, bool is_permutation, expr const &)> match_fn; // NOLINT
|
||||
typedef std::function<void(name const &, expr const &, expr const &, bool)> visit_fn;
|
||||
typedef std::function<bool(rewrite_rule const &)> match_fn; // NOLINT
|
||||
typedef std::function<void(rewrite_rule const &, bool)> visit_fn;
|
||||
|
||||
/**
|
||||
\brief Execute <tt>fn(lhs, ceq, is_perm, proof)</tt> for each (enabled) rule whose the left-hand-side may
|
||||
\brief Execute <tt>fn(rule)</tt> for each (enabled) rule whose the left-hand-side may
|
||||
match \c e.
|
||||
The traversal is interrupted as soon as \c fn returns true.
|
||||
|
||||
The redundant argument \c lhs is the left-hand-side of \c ceq.
|
||||
The redundant argument \c is_perm is true iff \c ceq is a permutation rule.
|
||||
|
||||
The argument \c proof is the proof for \c ceq.
|
||||
*/
|
||||
void for_each_match_candidate(expr const &, match_fn const & fn) const;
|
||||
bool find_match(expr const &, match_fn const & fn) const;
|
||||
|
||||
/** \brief Execute <tt>fn(id, ceq, proof, enabled)</tt> for each rule in this rule set. */
|
||||
/** \brief Execute <tt>fn(rule, enabled)</tt> for each rule in this rule set. */
|
||||
void for_each(visit_fn const & fn) const;
|
||||
|
||||
/** \brief Pretty print this rule set. */
|
||||
|
|
|
@ -92,6 +92,16 @@ class simplifier_fn {
|
|||
bool m_unfold;
|
||||
unsigned m_max_steps;
|
||||
|
||||
struct match_fn {
|
||||
simplifier_fn & m_simp;
|
||||
match_fn(simplifier_fn & s):m_simp(s) {}
|
||||
bool operator()(rewrite_rule const & rule) const {
|
||||
return m_simp.match(rule);
|
||||
}
|
||||
};
|
||||
|
||||
match_fn m_match_fn;
|
||||
|
||||
struct result {
|
||||
expr m_out;
|
||||
optional<expr> m_proof;
|
||||
|
@ -126,20 +136,20 @@ class simplifier_fn {
|
|||
|
||||
expr mk_congr1_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & Heq_f) {
|
||||
expr const & A = abst_domain(f_type);
|
||||
expr const & B = lower_free_vars(abst_body(f_type), 1, 1);
|
||||
expr B = lower_free_vars(abst_body(f_type), 1, 1);
|
||||
return ::lean::mk_congr1_th(A, B, f, new_f, a, Heq_f);
|
||||
}
|
||||
|
||||
expr mk_congr2_th(expr const & f_type, expr const & a, expr const & new_a, expr const & f, expr const & Heq_a) {
|
||||
expr const & A = abst_domain(f_type);
|
||||
expr const & B = lower_free_vars(abst_body(f_type), 1, 1);
|
||||
expr B = lower_free_vars(abst_body(f_type), 1, 1);
|
||||
return ::lean::mk_congr2_th(A, B, a, new_a, f, Heq_a);
|
||||
}
|
||||
|
||||
expr mk_congr_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & new_a,
|
||||
expr const & Heq_f, expr const & Heq_a) {
|
||||
expr const & A = abst_domain(f_type);
|
||||
expr const & B = lower_free_vars(abst_body(f_type), 1, 1);
|
||||
expr B = lower_free_vars(abst_body(f_type), 1, 1);
|
||||
return ::lean::mk_congr_th(A, B, f, new_f, a, new_a, Heq_f, Heq_a);
|
||||
}
|
||||
|
||||
|
@ -152,6 +162,65 @@ class simplifier_fn {
|
|||
f, new_f, a, new_a, Heq_f, Heq_a);
|
||||
}
|
||||
|
||||
result mk_trans_result(expr const & a, result const & b_res, expr const & c, expr const & H_bc) {
|
||||
if (m_proofs_enabled) {
|
||||
if (!b_res.m_proof) {
|
||||
// The proof of a = b is reflexivity
|
||||
return result(c, H_bc);
|
||||
} else {
|
||||
expr const & b = b_res.m_out;
|
||||
expr new_proof;
|
||||
bool heq_proof = false;
|
||||
if (b_res.m_heq_proof) {
|
||||
expr b_type = infer_type(b);
|
||||
new_proof = ::lean::mk_htrans_th(infer_type(a), b_type, b_type, /* b and c must have the same type */
|
||||
a, b, c, *b_res.m_proof, mk_to_heq_th(b_type, b, c, H_bc));
|
||||
heq_proof = true;
|
||||
} else {
|
||||
new_proof = ::lean::mk_trans_th(infer_type(a), a, b, c, *b_res.m_proof, H_bc);
|
||||
}
|
||||
return result(c, new_proof, heq_proof);
|
||||
}
|
||||
} else {
|
||||
return result(c);
|
||||
}
|
||||
}
|
||||
|
||||
result mk_trans_result(expr const & a, result const & b_res, result const & c_res) {
|
||||
if (m_proofs_enabled) {
|
||||
if (!b_res.m_proof) {
|
||||
// the proof of a == b is reflexivity
|
||||
return c_res;
|
||||
} else if (!c_res.m_proof) {
|
||||
// the proof of b == c is reflexivity
|
||||
return result(c_res.m_out, *b_res.m_proof, b_res.m_heq_proof);
|
||||
} else {
|
||||
bool heq_proof = b_res.m_heq_proof || c_res.m_heq_proof;
|
||||
expr new_proof;
|
||||
expr const & b = b_res.m_out;
|
||||
expr const & c = c_res.m_out;
|
||||
if (heq_proof) {
|
||||
expr a_type = infer_type(a);
|
||||
expr b_type = infer_type(b);
|
||||
expr c_type = infer_type(c);
|
||||
expr H_ab = *b_res.m_proof;
|
||||
if (!b_res.m_heq_proof)
|
||||
H_ab = mk_to_heq_th(a_type, a, b, H_ab);
|
||||
expr H_bc = *c_res.m_proof;
|
||||
if (!c_res.m_heq_proof)
|
||||
H_bc = mk_to_heq_th(b_type, b, c, H_bc);
|
||||
new_proof = ::lean::mk_htrans_th(a_type, b_type, c_type, a, b, c, H_ab, H_bc);
|
||||
} else {
|
||||
new_proof = ::lean::mk_trans_th(infer_type(a), a, b, c, *b_res.m_proof, *c_res.m_proof);
|
||||
}
|
||||
return result(c, new_proof, heq_proof);
|
||||
}
|
||||
} else {
|
||||
// proof generation is disabled
|
||||
return c_res;
|
||||
}
|
||||
}
|
||||
|
||||
expr mk_app_prefix(unsigned i, expr const & a) {
|
||||
lean_assert(i > 0);
|
||||
if (i == 1)
|
||||
|
@ -225,9 +294,9 @@ class simplifier_fn {
|
|||
}
|
||||
|
||||
if (!changed) {
|
||||
return rewrite_app(result(e));
|
||||
return rewrite(e, result(e));
|
||||
} else if (!m_proofs_enabled) {
|
||||
return rewrite_app(result(mk_app(new_args)));
|
||||
return rewrite(e, result(mk_app(new_args)));
|
||||
} else {
|
||||
expr out = mk_app(new_args);
|
||||
unsigned i = 0;
|
||||
|
@ -272,11 +341,76 @@ class simplifier_fn {
|
|||
pr = mk_congr1_th(f_types[i-1], f, new_f, arg(e, i), pr);
|
||||
}
|
||||
}
|
||||
return rewrite_app(result(out, pr, heq_proof));
|
||||
return rewrite(e, result(out, pr, heq_proof));
|
||||
}
|
||||
}
|
||||
|
||||
result rewrite_app(result const & r) {
|
||||
expr m_target; // temp field
|
||||
buffer<optional<expr>> m_subst; // temp field
|
||||
buffer<expr> m_new_args; // temp field
|
||||
expr m_new_rhs; // temp field
|
||||
expr m_new_proof; // temp field
|
||||
|
||||
void reset_subst(unsigned num_args) {
|
||||
if (m_subst.size() < num_args) {
|
||||
m_subst.resize(num_args);
|
||||
m_new_args.resize(num_args+1);
|
||||
}
|
||||
for (unsigned i = 0; i < num_args; i++)
|
||||
m_subst[i] = none_expr();
|
||||
}
|
||||
|
||||
bool found_all_args(unsigned num_args) {
|
||||
for (unsigned i = 0; i < num_args; i++) {
|
||||
if (!m_subst[i])
|
||||
return false;
|
||||
m_new_args[i+1] = *m_subst[i];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Auxiliary function used by m_match_fn, it tries to match the given rule and
|
||||
the expression in the temporary field \c m_target.
|
||||
If it succeeds, then the resultant expression is stored in the temporary field m_new_rhs,
|
||||
and the proof in m_new_proof (if proofs are enabled).
|
||||
*/
|
||||
bool match(rewrite_rule const & rule) {
|
||||
unsigned num = rule.get_num_args();
|
||||
reset_subst(num);
|
||||
if (hop_match(rule.get_lhs(), m_target, m_subst)) {
|
||||
if (found_all_args(num)) {
|
||||
// easy case: all arguments found
|
||||
m_new_rhs = instantiate(rule.get_rhs(), num, m_new_args.data() + 1);
|
||||
if (m_proofs_enabled) {
|
||||
if (num > 0) {
|
||||
m_new_args[0] = rule.get_proof();
|
||||
m_new_proof = mk_app(m_new_args);
|
||||
} else {
|
||||
m_new_proof = rule.get_proof();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// TODO(Leo): conditional rewriting
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
result rewrite(expr const & e, result const & r) {
|
||||
m_target = r.m_out;
|
||||
for (rewrite_rule_set const & rs : m_rule_sets) {
|
||||
if (rs.find_match(m_target, m_match_fn)) {
|
||||
// the result is in m_new_rhs and proof at m_new_proof
|
||||
result new_r1 = mk_trans_result(e, r, m_new_rhs, m_new_proof);
|
||||
if (m_single_pass) {
|
||||
return new_r1;
|
||||
} else {
|
||||
result new_r2 = simplify(new_r1.m_out);
|
||||
return mk_trans_result(e, new_r1, new_r2);
|
||||
}
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
|
@ -302,17 +436,7 @@ class simplifier_fn {
|
|||
}
|
||||
}
|
||||
}
|
||||
#if 1
|
||||
if (const_name(e) == "a") {
|
||||
auto obj = m_env->find_object("a_eq_0");
|
||||
if (obj) {
|
||||
expr r = arg(obj->get_type(), 3);
|
||||
return result(r, mk_constant("a_eq_0"));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return result(e);
|
||||
return rewrite(e, result(e));
|
||||
}
|
||||
|
||||
result simplify_lambda(expr const & e) {
|
||||
|
@ -328,11 +452,11 @@ class simplifier_fn {
|
|||
if (is_eqp(new_body, abst_body(e)))
|
||||
return result(e);
|
||||
expr out = mk_lambda(e, new_body);
|
||||
if (!m_proofs_enabled)
|
||||
if (!m_proofs_enabled || !res_body.m_proof)
|
||||
return result(out);
|
||||
expr body_type = infer_type(abst_body(e));
|
||||
expr pr = mk_funext_th(abst_domain(e), mk_lambda(e, body_type), e, out,
|
||||
mk_lambda(e, *(res_body.m_proof)));
|
||||
mk_lambda(e, *res_body.m_proof));
|
||||
return result(out, pr);
|
||||
}
|
||||
}
|
||||
|
@ -350,12 +474,12 @@ class simplifier_fn {
|
|||
if (is_eqp(new_body, abst_body(e)))
|
||||
return result(e);
|
||||
expr out = mk_pi(abst_name(e), abst_domain(e), new_body);
|
||||
if (!m_proofs_enabled)
|
||||
if (!m_proofs_enabled || !res_body.m_proof)
|
||||
return result(out);
|
||||
expr pr = mk_allext_th(abst_domain(e),
|
||||
mk_lambda(e, abst_body(e)),
|
||||
mk_lambda(e, abst_body(out)),
|
||||
mk_lambda(e, *(res_body.m_proof)));
|
||||
mk_lambda(e, *res_body.m_proof));
|
||||
return result(out, pr);
|
||||
} else {
|
||||
// if the environment does not contain heq axioms, then we don't simplify Pi's that are not forall's
|
||||
|
@ -384,13 +508,13 @@ class simplifier_fn {
|
|||
m_contextual = get_simplifier_contextual(o);
|
||||
m_single_pass = get_simplifier_single_pass(o);
|
||||
m_beta = get_simplifier_beta(o);
|
||||
m_unfold = true; // get_simplifier_unfold(o);
|
||||
m_unfold = get_simplifier_unfold(o);
|
||||
m_max_steps = get_simplifier_max_steps(o);
|
||||
}
|
||||
|
||||
public:
|
||||
simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs):
|
||||
m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs) {
|
||||
m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs), m_match_fn(*this) {
|
||||
m_has_heq = m_env->imported("heq");
|
||||
set_options(o);
|
||||
}
|
||||
|
|
30
tests/lua/simp1.lua
Normal file
30
tests/lua/simp1.lua
Normal file
|
@ -0,0 +1,30 @@
|
|||
mk_rewrite_rule_set()
|
||||
add_rewrite_rules({"Nat", "add_zerol"})
|
||||
add_rewrite_rules({"Nat", "add_zeror"})
|
||||
parse_lean_cmds([[
|
||||
variable f : Nat -> Nat -> Nat
|
||||
variable g : Nat -> Nat
|
||||
variable b : Nat
|
||||
definition a := 1
|
||||
theorem a_eq_1 : a = 1
|
||||
:= refl a
|
||||
definition c := 1
|
||||
set_opaque a true
|
||||
|
||||
axiom f_id (x : Nat) : f x 1 = 2*x
|
||||
]])
|
||||
add_rewrite_rules("a_eq_1")
|
||||
add_rewrite_rules("f_id")
|
||||
-- set_option({"lean", "pp", "implicit"}, true)
|
||||
e, pr = simplify(parse_lean('fun x, f (f x (0 + a)) (g (b + 0))'))
|
||||
print(e)
|
||||
print(pr)
|
||||
local env = get_environment()
|
||||
print(env:type_check(pr))
|
||||
|
||||
|
||||
e, pr = simplify(parse_lean('forall x, let d := a + 1 in f x a >= d'))
|
||||
print(e)
|
||||
print(pr)
|
||||
local env = get_environment()
|
||||
print(env:type_check(pr))
|
Loading…
Reference in a new issue