feat(library/blast/fusion): refactor

This commit is contained in:
Daniel Selsam 2015-11-13 11:29:09 -08:00
parent d852be0d79
commit f72da014d4
4 changed files with 140 additions and 51 deletions

View file

@ -21,6 +21,7 @@ Author: Daniel Selsam
#include "util/pair.h"
#include "util/sexpr/option_declarations.h"
#include <functional>
#include <iostream>
#ifndef LEAN_DEFAULT_SIMPLIFY_MAX_STEPS
#define LEAN_DEFAULT_SIMPLIFY_MAX_STEPS 1000
@ -95,6 +96,25 @@ bool get_simplify_fuse() {
return ios().get_options().get_bool(*g_simplify_fuse, LEAN_DEFAULT_SIMPLIFY_FUSE);
}
/* Miscellaneous helpers */
static bool is_const_app(expr const & e, name const & n, unsigned nargs) {
expr const & f = get_app_fn(e);
return is_constant(f) && const_name(f) == n && get_app_num_args(e) == nargs;
}
static bool is_add_app(expr const & e) {
return is_const_app(e, get_add_name(), 4);
}
static bool is_mul_app(expr const & e) {
return is_const_app(e, get_mul_name(), 4);
}
static bool is_neg_app(expr const & e) {
return is_const_app(e, get_neg_name(), 3);
}
/* Main simplifier class */
class simplifier {
@ -166,6 +186,10 @@ class simplifier {
return srss;
}
bool instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned num_emeta,
list<expr> const & emetas, list<bool> const & instances);
/* Results */
result lift_from_eq(expr const & x, result const & r);
result join(result const & r1, result const & r2);
@ -173,12 +197,15 @@ class simplifier {
result finalize(result const & r);
/* Simplification */
result simplify(expr const & e);
result simplify(expr const & e, bool is_root);
result simplify_lambda(expr const & e);
result simplify_pi(expr const & e);
result simplify_app(expr const & e);
result simplify_fun(expr const & e);
/* Proving */
optional<expr> prove(expr const & thm);
/* Rewriting */
result rewrite(expr const & e);
result rewrite(expr const & e, simp_rule_sets const & srss);
@ -193,12 +220,20 @@ class simplifier {
result try_congrs(expr const & e);
result try_congr(expr const & e, congr_rule const & cr);
bool instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned num_emeta,
list<expr> const & emetas, list<bool> const & instances);
template<typename F>
optional<result> synth_congr(expr const & e, F && simp);
/* Fusion */
std::array<bool, 1> tc_mask1{{true}};
std::array<bool, 2> tc_mask2{{true, false}};
result maybe_fuse(expr const & e, bool is_root);
result fuse(expr const & e);
expr_pair split_summand(expr const & e, expr const & f_mul, expr const & one);
public:
simplifier(name const & rel, simp_rule_sets const & srss);
result operator()(expr const & e) { return simplify(e); }
result operator()(expr const & e) { return simplify(e, true); }
};
/* Constructor */
@ -257,7 +292,7 @@ result simplifier::finalize(result const & r) {
/* Simplification */
result simplifier::simplify(expr const & e) {
result simplifier::simplify(expr const & e, bool is_root) {
m_num_steps++;
flet<unsigned> inc_depth(m_depth, m_depth+1);
@ -289,7 +324,7 @@ result simplifier::simplify(expr const & e) {
lean_unreachable();
case expr_kind::Macro:
if (m_expand_macros) {
if (auto m = m_tmp_tctx->expand_macro(e)) r = join(r, simplify(whnf(*m)));
if (auto m = m_tmp_tctx->expand_macro(e)) r = join(r, simplify(whnf(*m), is_root));
}
break;
case expr_kind::Lambda:
@ -308,12 +343,12 @@ result simplifier::simplify(expr const & e) {
if (r.get_new() == e && !using_eq()) {
{
flet<name> use_eq(m_rel, get_eq_name());
r = simplify(r.get_new());
r = simplify(r.get_new(), is_root);
}
if (!r.is_none()) r = lift_from_eq(e, r);
}
if (m_exhaustive && r.get_new() != e) r = join(r, simplify(r.get_new()));
if (m_exhaustive && r.get_new() != e) r = join(r, simplify(r.get_new(), is_root));
if (m_memoize) cache_save(e, r);
@ -332,7 +367,7 @@ result simplifier::simplify_lambda(expr const & _e) {
e = instantiate(binding_body(e), l);
}
result r = simplify(e);
result r = simplify(e, false);
if (r.is_none()) { return result(_e); }
for (int i = ls.size() - 1; i >= 0; --i) r = funext(r, ls[i]);
@ -357,38 +392,10 @@ result simplifier::simplify_app(expr const & e) {
/* (2) Synthesize congruence lemma */
if (using_eq()) {
buffer<expr> args;
expr fn = get_app_args(e, args);
if (auto congr_lemma = mk_congr_lemma_for_simp(fn, args.size())) {
expr proof = congr_lemma->get_proof();
expr type = congr_lemma->get_type();
unsigned i = 0;
bool simplified = false;
buffer<expr> locals;
for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) {
proof = mk_app(proof, args[i]);
type = instantiate(binding_body(type), args[i]);
if (ckind == congr_arg_kind::Eq) {
result r_arg = simplify(args[i]);
if (!r_arg.is_none()) simplified = true;
r_arg = finalize(r_arg);
proof = mk_app(proof, r_arg.get_new(), r_arg.get_proof());
type = instantiate(binding_body(type), r_arg.get_new());
type = instantiate(binding_body(type), r_arg.get_proof());
}
i++;
});
if (simplified) {
lean_assert(is_eq(type));
buffer<expr> type_args;
get_app_args(type, type_args);
expr & new_e = type_args[2];
return join(result(new_e, proof), simplify_fun(new_e));
} else {
return simplify_fun(e);
}
}
optional<result> r_args = synth_congr(e, [&](expr const & e) {
return simplify(e, false);
});
if (r_args) return join(*r_args, simplify_fun(r_args->get_new()));
}
/* (3) Fall back on generic binary congruence */
@ -396,13 +403,16 @@ result simplifier::simplify_app(expr const & e) {
expr const & f = app_fn(e);
expr const & arg = app_arg(e);
result r_f = simplify(f);
// TODO(dhs): it is not clear if this recursive call should be considered
// a root or not, though does not matter since if + were being applied,
// we would have synthesized a congruence rule in step (2).
result r_f = simplify(f, false);
if (is_dependent_fn(f)) {
if (r_f.is_none()) return e;
else return congr_fun(r_f, arg);
} else {
result r_arg = simplify(arg);
result r_arg = simplify(arg, false);
if (r_f.is_none() && r_arg.is_none()) return e;
else if (r_f.is_none()) return congr_arg(f, r_arg);
else if (r_arg.is_none()) return congr_fun(r_f, arg);
@ -417,9 +427,23 @@ result simplifier::simplify_fun(expr const & e) {
lean_assert(is_app(e));
buffer<expr> args;
expr const & f = get_app_args(e, args);
result r_f = simplify(f);
result r_f = simplify(f, true);
if (r_f.is_none()) return result(e);
else return congr_funs(simplify(f), args);
else return congr_funs(r_f, args);
}
/* Proving */
optional<expr> simplifier::prove(expr const & thm) {
flet<name> set_name(m_rel, get_iff_name());
result r_cond = simplify(thm, true);
if (is_constant(r_cond.get_new()) && const_name(r_cond.get_new()) == get_true_name()) {
expr pf = m_app_builder.mk_app(get_iff_elim_right_name(),
finalize(r_cond).get_proof(),
mk_constant(get_true_intro_name()));
return some_expr(pf);
}
return none_expr();
}
/* Rewriting */
@ -576,7 +600,7 @@ result simplifier::try_congr(expr const & e, congr_rule const & cr) {
h_lhs = tmp_tctx->instantiate_uvars_mvars(h_lhs);
lean_assert(!has_metavar(h_lhs));
result r_congr_hyp = simplify(h_lhs);
result r_congr_hyp = simplify(h_lhs, true);
expr hyp;
if (r_congr_hyp.is_none()) {
hyp = finalize(r_congr_hyp).get_proof();
@ -633,11 +657,8 @@ bool simplifier::instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned
if (tmp_tctx->is_mvar_assigned(i)) return;
if (tmp_tctx->is_prop(m_type)) {
flet<name> set_name(m_rel, get_iff_name());
result r_cond = simplify(m_type);
if (is_constant(r_cond.get_new()) && const_name(r_cond.get_new()) == get_true_name()) {
expr pf = m_app_builder.mk_app(name("iff", "elim_right"), finalize(r_cond).get_proof(), mk_constant(get_true_intro_name()));
lean_verify(tmp_tctx->is_def_eq(m, pf));
if (auto pf = prove(m_type)) {
lean_verify(tmp_tctx->is_def_eq(m, *pf));
return;
}
}
@ -653,6 +674,44 @@ bool simplifier::instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned
return !failed;
}
template<typename F>
optional<result> simplifier::synth_congr(expr const & e, F && simp) {
static_assert(std::is_same<typename std::result_of<F(expr const & e)>::type, result>::value,
"synth_congr: simp must take expressions to results");
lean_assert(is_app(e));
buffer<expr> args;
expr f = get_app_args(e, args);
auto congr_lemma = mk_congr_lemma_for_simp(f, args.size());
if (!congr_lemma) return optional<result>();
expr proof = congr_lemma->get_proof();
expr type = congr_lemma->get_type();
unsigned i = 0;
bool simplified = false;
buffer<expr> locals;
for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) {
proof = mk_app(proof, args[i]);
type = instantiate(binding_body(type), args[i]);
if (ckind == congr_arg_kind::Eq) {
result r_arg = simp(args[i]);
if (!r_arg.is_none()) simplified = true;
r_arg = finalize(r_arg);
proof = mk_app(proof, r_arg.get_new(), r_arg.get_proof());
type = instantiate(binding_body(type), r_arg.get_new());
type = instantiate(binding_body(type), r_arg.get_proof());
}
i++;
});
if (simplified) {
lean_assert(is_eq(type));
buffer<expr> type_args;
get_app_args(type, type_args);
expr & new_e = type_args[2];
return optional<result>(result(new_e, proof));
} else {
return optional<result>(result(e));
}
}
/* Setup and teardown */
void initialize_simplifier() {

View file

@ -4,6 +4,7 @@
#include "util/name.h"
namespace lean{
name const * g_absurd = nullptr;
name const * g_add = nullptr;
name const * g_and = nullptr;
name const * g_and_elim_left = nullptr;
name const * g_and_elim_right = nullptr;
@ -53,6 +54,8 @@ name const * g_iff_symm = nullptr;
name const * g_iff_trans = nullptr;
name const * g_iff_mp = nullptr;
name const * g_iff_mpr = nullptr;
name const * g_iff_elim_left = nullptr;
name const * g_iff_elim_right = nullptr;
name const * g_iff_false_intro = nullptr;
name const * g_iff_true_intro = nullptr;
name const * g_implies = nullptr;
@ -63,10 +66,12 @@ name const * g_ite = nullptr;
name const * g_lift = nullptr;
name const * g_lift_down = nullptr;
name const * g_lift_up = nullptr;
name const * g_mul = nullptr;
name const * g_nat = nullptr;
name const * g_nat_of_num = nullptr;
name const * g_nat_succ = nullptr;
name const * g_nat_zero = nullptr;
name const * g_neg = nullptr;
name const * g_not = nullptr;
name const * g_num = nullptr;
name const * g_num_zero = nullptr;
@ -168,6 +173,7 @@ name const * g_well_founded = nullptr;
name const * g_zero = nullptr;
void initialize_constants() {
g_absurd = new name{"absurd"};
g_add = new name{"add"};
g_and = new name{"and"};
g_and_elim_left = new name{"and", "elim_left"};
g_and_elim_right = new name{"and", "elim_right"};
@ -217,6 +223,8 @@ void initialize_constants() {
g_iff_trans = new name{"iff", "trans"};
g_iff_mp = new name{"iff", "mp"};
g_iff_mpr = new name{"iff", "mpr"};
g_iff_elim_left = new name{"iff", "elim_left"};
g_iff_elim_right = new name{"iff", "elim_right"};
g_iff_false_intro = new name{"iff_false_intro"};
g_iff_true_intro = new name{"iff_true_intro"};
g_implies = new name{"implies"};
@ -227,10 +235,12 @@ void initialize_constants() {
g_lift = new name{"lift"};
g_lift_down = new name{"lift", "down"};
g_lift_up = new name{"lift", "up"};
g_mul = new name{"mul"};
g_nat = new name{"nat"};
g_nat_of_num = new name{"nat", "of_num"};
g_nat_succ = new name{"nat", "succ"};
g_nat_zero = new name{"nat", "zero"};
g_neg = new name{"neg"};
g_not = new name{"not"};
g_num = new name{"num"};
g_num_zero = new name{"num", "zero"};
@ -333,6 +343,7 @@ void initialize_constants() {
}
void finalize_constants() {
delete g_absurd;
delete g_add;
delete g_and;
delete g_and_elim_left;
delete g_and_elim_right;
@ -382,6 +393,8 @@ void finalize_constants() {
delete g_iff_trans;
delete g_iff_mp;
delete g_iff_mpr;
delete g_iff_elim_left;
delete g_iff_elim_right;
delete g_iff_false_intro;
delete g_iff_true_intro;
delete g_implies;
@ -392,10 +405,12 @@ void finalize_constants() {
delete g_lift;
delete g_lift_down;
delete g_lift_up;
delete g_mul;
delete g_nat;
delete g_nat_of_num;
delete g_nat_succ;
delete g_nat_zero;
delete g_neg;
delete g_not;
delete g_num;
delete g_num_zero;
@ -497,6 +512,7 @@ void finalize_constants() {
delete g_zero;
}
name const & get_absurd_name() { return *g_absurd; }
name const & get_add_name() { return *g_add; }
name const & get_and_name() { return *g_and; }
name const & get_and_elim_left_name() { return *g_and_elim_left; }
name const & get_and_elim_right_name() { return *g_and_elim_right; }
@ -546,6 +562,8 @@ name const & get_iff_symm_name() { return *g_iff_symm; }
name const & get_iff_trans_name() { return *g_iff_trans; }
name const & get_iff_mp_name() { return *g_iff_mp; }
name const & get_iff_mpr_name() { return *g_iff_mpr; }
name const & get_iff_elim_left_name() { return *g_iff_elim_left; }
name const & get_iff_elim_right_name() { return *g_iff_elim_right; }
name const & get_iff_false_intro_name() { return *g_iff_false_intro; }
name const & get_iff_true_intro_name() { return *g_iff_true_intro; }
name const & get_implies_name() { return *g_implies; }
@ -556,10 +574,12 @@ name const & get_ite_name() { return *g_ite; }
name const & get_lift_name() { return *g_lift; }
name const & get_lift_down_name() { return *g_lift_down; }
name const & get_lift_up_name() { return *g_lift_up; }
name const & get_mul_name() { return *g_mul; }
name const & get_nat_name() { return *g_nat; }
name const & get_nat_of_num_name() { return *g_nat_of_num; }
name const & get_nat_succ_name() { return *g_nat_succ; }
name const & get_nat_zero_name() { return *g_nat_zero; }
name const & get_neg_name() { return *g_neg; }
name const & get_not_name() { return *g_not; }
name const & get_num_name() { return *g_num; }
name const & get_num_zero_name() { return *g_num_zero; }

View file

@ -6,6 +6,7 @@ namespace lean {
void initialize_constants();
void finalize_constants();
name const & get_absurd_name();
name const & get_add_name();
name const & get_and_name();
name const & get_and_elim_left_name();
name const & get_and_elim_right_name();
@ -55,6 +56,8 @@ name const & get_iff_symm_name();
name const & get_iff_trans_name();
name const & get_iff_mp_name();
name const & get_iff_mpr_name();
name const & get_iff_elim_left_name();
name const & get_iff_elim_right_name();
name const & get_iff_false_intro_name();
name const & get_iff_true_intro_name();
name const & get_implies_name();
@ -65,10 +68,12 @@ name const & get_ite_name();
name const & get_lift_name();
name const & get_lift_down_name();
name const & get_lift_up_name();
name const & get_mul_name();
name const & get_nat_name();
name const & get_nat_of_num_name();
name const & get_nat_succ_name();
name const & get_nat_zero_name();
name const & get_neg_name();
name const & get_not_name();
name const & get_num_name();
name const & get_num_zero_name();

View file

@ -1,4 +1,5 @@
absurd
add
and
and.elim_left
and.elim_right
@ -48,6 +49,8 @@ iff.symm
iff.trans
iff.mp
iff.mpr
iff.elim_left
iff.elim_right
iff_false_intro
iff_true_intro
implies
@ -58,10 +61,12 @@ ite
lift
lift.down
lift.up
mul
nat
nat.of_num
nat.succ
nat.zero
neg
not
num
num.zero