feat(kernel/simplifier): add support for Beta-reduction in the simplifier

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-19 00:04:44 -08:00
parent 7a3aab60c6
commit ed009f4c88
2 changed files with 45 additions and 12 deletions

View file

@ -38,6 +38,10 @@ Author: Leonardo de Moura
#define LEAN_SIMPLIFIER_UNFOLD false
#endif
#ifndef LEAN_SIMPLIFIER_CONDITIONAL
#define LEAN_SIMPLIFIER_CONDITIONAL true
#endif
#ifndef LEAN_SIMPLIFIER_MAX_STEPS
#define LEAN_SIMPLIFIER_MAX_STEPS std::numeric_limits<unsigned>::max()
#endif
@ -48,6 +52,7 @@ static name g_simplifier_contextual {"simplifier", "contextual"};
static name g_simplifier_single_pass {"simplifier", "single_pass"};
static name g_simplifier_beta {"simplifier", "beta"};
static name g_simplifier_unfold {"simplifier", "unfold"};
static name g_simplifier_conditional {"simplifier", "conditional"};
static name g_simplifier_max_steps {"simplifier", "max_steps"};
RegisterBoolOption(g_simplifier_proofs, LEAN_SIMPLIFIER_PROOFS, "(simplifier) generate proofs");
@ -55,6 +60,7 @@ RegisterBoolOption(g_simplifier_contextual, LEAN_SIMPLIFIER_CONTEXTUAL, "(simpli
RegisterBoolOption(g_simplifier_single_pass, LEAN_SIMPLIFIER_SINGLE_PASS, "(simplifier) if false then the simplifier keeps applying simplifications as long as possible");
RegisterBoolOption(g_simplifier_beta, LEAN_SIMPLIFIER_BETA, "(simplifier) use beta-reductions");
RegisterBoolOption(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD, "(simplifier) unfolds non-opaque definitions");
RegisterBoolOption(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL, "(simplifier) conditional rewriting");
RegisterUnsignedOption(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS, "(simplifier) maximum number of steps");
bool get_simplifier_proofs(options const & opts) {
@ -72,6 +78,9 @@ bool get_simplifier_beta(options const & opts) {
bool get_simplifier_unfold(options const & opts) {
return opts.get_bool(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD);
}
bool get_simplifier_conditional(options const & opts) {
return opts.get_bool(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL);
}
unsigned get_simplifier_max_steps(options const & opts) {
return opts.get_unsigned(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS);
}
@ -90,6 +99,7 @@ class simplifier_fn {
bool m_single_pass;
bool m_beta;
bool m_unfold;
bool m_conditional;
unsigned m_max_steps;
struct match_fn {
@ -106,8 +116,12 @@ class simplifier_fn {
expr m_out;
optional<expr> m_proof;
bool m_heq_proof;
explicit result(expr const & out, bool heq_proof = false):m_out(out), m_heq_proof(heq_proof) {}
result(expr const & out, expr const & pr, bool heq_proof = false):m_out(out), m_proof(pr), m_heq_proof(heq_proof) {}
explicit result(expr const & out, bool heq_proof = false):
m_out(out), m_heq_proof(heq_proof) {}
result(expr const & out, expr const & pr, bool heq_proof = false):
m_out(out), m_proof(pr), m_heq_proof(heq_proof) {}
result(expr const & out, optional<expr> const & pr, bool heq_proof = false):
m_out(out), m_proof(pr), m_heq_proof(heq_proof) {}
};
struct set_context {
@ -294,9 +308,9 @@ class simplifier_fn {
}
if (!changed) {
return rewrite(e, result(e));
return rewrite_app(e, result(e));
} else if (!m_proofs_enabled) {
return rewrite(e, result(mk_app(new_args)));
return rewrite_app(e, result(mk_app(new_args)));
} else {
expr out = mk_app(new_args);
unsigned i = 0;
@ -310,7 +324,7 @@ class simplifier_fn {
bool heq_proof = false;
if (i == 0) {
pr = *(proofs[0]);
heq_proof = heq_proofs[0];
heq_proof = m_has_heq && heq_proofs[0];
} else if (m_has_heq && heq_proofs[i]) {
expr f = mk_app_prefix(i, new_args);
pr = mk_hcongr_th(f_types[i-1], f_types[i-1], f, f, arg(e, i), new_args[i],
@ -341,10 +355,17 @@ class simplifier_fn {
pr = mk_congr1_th(f_types[i-1], f, new_f, arg(e, i), pr);
}
}
return rewrite(e, result(out, pr, heq_proof));
return rewrite_app(e, result(out, pr, heq_proof));
}
}
result rewrite_app(expr const & e, result const & r) {
if (m_beta && is_lambda(arg(r.m_out, 0)))
return rewrite(e, result(head_beta_reduce(r.m_out), r.m_proof, r.m_heq_proof));
else
return rewrite(e, r);
}
expr m_target; // temp field
buffer<optional<expr>> m_subst; // temp field
buffer<expr> m_new_args; // temp field
@ -439,6 +460,11 @@ class simplifier_fn {
return rewrite(e, result(e));
}
result rewrite_lambda(expr const & e, result const & r) {
rewrite(e, r);
}
result simplify_lambda(expr const & e) {
lean_assert(is_lambda(e));
if (m_has_heq) {
@ -450,14 +476,14 @@ class simplifier_fn {
lean_assert(!res_body.m_heq_proof);
expr new_body = res_body.m_out;
if (is_eqp(new_body, abst_body(e)))
return result(e);
return rewrite_lambda(e, result(e));
expr out = mk_lambda(e, new_body);
if (!m_proofs_enabled || !res_body.m_proof)
return result(out);
return rewrite_lambda(e, result(out));
expr body_type = infer_type(abst_body(e));
expr pr = mk_funext_th(abst_domain(e), mk_lambda(e, body_type), e, out,
mk_lambda(e, *res_body.m_proof));
return result(out, pr);
return rewrite_lambda(e, result(out, pr));
}
}
@ -472,15 +498,15 @@ class simplifier_fn {
lean_assert(!res_body.m_heq_proof);
expr new_body = res_body.m_out;
if (is_eqp(new_body, abst_body(e)))
return result(e);
return rewrite(e, result(e));
expr out = mk_pi(abst_name(e), abst_domain(e), new_body);
if (!m_proofs_enabled || !res_body.m_proof)
return result(out);
return rewrite(e, result(out));
expr pr = mk_allext_th(abst_domain(e),
mk_lambda(e, abst_body(e)),
mk_lambda(e, abst_body(out)),
mk_lambda(e, *res_body.m_proof));
return result(out, pr);
return rewrite(e, result(out, pr));
} else {
// if the environment does not contain heq axioms, then we don't simplify Pi's that are not forall's
return result(e);
@ -509,6 +535,7 @@ class simplifier_fn {
m_single_pass = get_simplifier_single_pass(o);
m_beta = get_simplifier_beta(o);
m_unfold = get_simplifier_unfold(o);
m_conditional = get_simplifier_conditional(o);
m_max_steps = get_simplifier_max_steps(o);
}

View file

@ -28,3 +28,9 @@ print(e)
print(pr)
local env = get_environment()
print(env:type_check(pr))
e, pr = simplify(parse_lean('(fun x, f (f x (0 + a)) (g (b + 0))) b'))
print(e)
print(pr)
local env = get_environment()
print(env:type_check(pr))