feat(library/blast/fusion): refactor
This commit is contained in:
parent
d852be0d79
commit
f72da014d4
4 changed files with 140 additions and 51 deletions
|
@ -21,6 +21,7 @@ Author: Daniel Selsam
|
|||
#include "util/pair.h"
|
||||
#include "util/sexpr/option_declarations.h"
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef LEAN_DEFAULT_SIMPLIFY_MAX_STEPS
|
||||
#define LEAN_DEFAULT_SIMPLIFY_MAX_STEPS 1000
|
||||
|
@ -95,6 +96,25 @@ bool get_simplify_fuse() {
|
|||
return ios().get_options().get_bool(*g_simplify_fuse, LEAN_DEFAULT_SIMPLIFY_FUSE);
|
||||
}
|
||||
|
||||
/* Miscellaneous helpers */
|
||||
|
||||
static bool is_const_app(expr const & e, name const & n, unsigned nargs) {
|
||||
expr const & f = get_app_fn(e);
|
||||
return is_constant(f) && const_name(f) == n && get_app_num_args(e) == nargs;
|
||||
}
|
||||
|
||||
static bool is_add_app(expr const & e) {
|
||||
return is_const_app(e, get_add_name(), 4);
|
||||
}
|
||||
|
||||
static bool is_mul_app(expr const & e) {
|
||||
return is_const_app(e, get_mul_name(), 4);
|
||||
}
|
||||
|
||||
static bool is_neg_app(expr const & e) {
|
||||
return is_const_app(e, get_neg_name(), 3);
|
||||
}
|
||||
|
||||
/* Main simplifier class */
|
||||
|
||||
class simplifier {
|
||||
|
@ -166,6 +186,10 @@ class simplifier {
|
|||
return srss;
|
||||
}
|
||||
|
||||
|
||||
bool instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned num_emeta,
|
||||
list<expr> const & emetas, list<bool> const & instances);
|
||||
|
||||
/* Results */
|
||||
result lift_from_eq(expr const & x, result const & r);
|
||||
result join(result const & r1, result const & r2);
|
||||
|
@ -173,12 +197,15 @@ class simplifier {
|
|||
result finalize(result const & r);
|
||||
|
||||
/* Simplification */
|
||||
result simplify(expr const & e);
|
||||
result simplify(expr const & e, bool is_root);
|
||||
result simplify_lambda(expr const & e);
|
||||
result simplify_pi(expr const & e);
|
||||
result simplify_app(expr const & e);
|
||||
result simplify_fun(expr const & e);
|
||||
|
||||
/* Proving */
|
||||
optional<expr> prove(expr const & thm);
|
||||
|
||||
/* Rewriting */
|
||||
result rewrite(expr const & e);
|
||||
result rewrite(expr const & e, simp_rule_sets const & srss);
|
||||
|
@ -193,12 +220,20 @@ class simplifier {
|
|||
result try_congrs(expr const & e);
|
||||
result try_congr(expr const & e, congr_rule const & cr);
|
||||
|
||||
bool instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned num_emeta,
|
||||
list<expr> const & emetas, list<bool> const & instances);
|
||||
template<typename F>
|
||||
optional<result> synth_congr(expr const & e, F && simp);
|
||||
|
||||
/* Fusion */
|
||||
std::array<bool, 1> tc_mask1{{true}};
|
||||
std::array<bool, 2> tc_mask2{{true, false}};
|
||||
result maybe_fuse(expr const & e, bool is_root);
|
||||
result fuse(expr const & e);
|
||||
expr_pair split_summand(expr const & e, expr const & f_mul, expr const & one);
|
||||
|
||||
|
||||
public:
|
||||
simplifier(name const & rel, simp_rule_sets const & srss);
|
||||
result operator()(expr const & e) { return simplify(e); }
|
||||
result operator()(expr const & e) { return simplify(e, true); }
|
||||
};
|
||||
|
||||
/* Constructor */
|
||||
|
@ -257,7 +292,7 @@ result simplifier::finalize(result const & r) {
|
|||
|
||||
/* Simplification */
|
||||
|
||||
result simplifier::simplify(expr const & e) {
|
||||
result simplifier::simplify(expr const & e, bool is_root) {
|
||||
m_num_steps++;
|
||||
flet<unsigned> inc_depth(m_depth, m_depth+1);
|
||||
|
||||
|
@ -289,7 +324,7 @@ result simplifier::simplify(expr const & e) {
|
|||
lean_unreachable();
|
||||
case expr_kind::Macro:
|
||||
if (m_expand_macros) {
|
||||
if (auto m = m_tmp_tctx->expand_macro(e)) r = join(r, simplify(whnf(*m)));
|
||||
if (auto m = m_tmp_tctx->expand_macro(e)) r = join(r, simplify(whnf(*m), is_root));
|
||||
}
|
||||
break;
|
||||
case expr_kind::Lambda:
|
||||
|
@ -308,12 +343,12 @@ result simplifier::simplify(expr const & e) {
|
|||
if (r.get_new() == e && !using_eq()) {
|
||||
{
|
||||
flet<name> use_eq(m_rel, get_eq_name());
|
||||
r = simplify(r.get_new());
|
||||
r = simplify(r.get_new(), is_root);
|
||||
}
|
||||
if (!r.is_none()) r = lift_from_eq(e, r);
|
||||
}
|
||||
|
||||
if (m_exhaustive && r.get_new() != e) r = join(r, simplify(r.get_new()));
|
||||
if (m_exhaustive && r.get_new() != e) r = join(r, simplify(r.get_new(), is_root));
|
||||
|
||||
if (m_memoize) cache_save(e, r);
|
||||
|
||||
|
@ -332,7 +367,7 @@ result simplifier::simplify_lambda(expr const & _e) {
|
|||
e = instantiate(binding_body(e), l);
|
||||
}
|
||||
|
||||
result r = simplify(e);
|
||||
result r = simplify(e, false);
|
||||
if (r.is_none()) { return result(_e); }
|
||||
|
||||
for (int i = ls.size() - 1; i >= 0; --i) r = funext(r, ls[i]);
|
||||
|
@ -357,38 +392,10 @@ result simplifier::simplify_app(expr const & e) {
|
|||
|
||||
/* (2) Synthesize congruence lemma */
|
||||
if (using_eq()) {
|
||||
buffer<expr> args;
|
||||
expr fn = get_app_args(e, args);
|
||||
if (auto congr_lemma = mk_congr_lemma_for_simp(fn, args.size())) {
|
||||
expr proof = congr_lemma->get_proof();
|
||||
expr type = congr_lemma->get_type();
|
||||
unsigned i = 0;
|
||||
bool simplified = false;
|
||||
buffer<expr> locals;
|
||||
for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) {
|
||||
proof = mk_app(proof, args[i]);
|
||||
type = instantiate(binding_body(type), args[i]);
|
||||
|
||||
if (ckind == congr_arg_kind::Eq) {
|
||||
result r_arg = simplify(args[i]);
|
||||
if (!r_arg.is_none()) simplified = true;
|
||||
r_arg = finalize(r_arg);
|
||||
proof = mk_app(proof, r_arg.get_new(), r_arg.get_proof());
|
||||
type = instantiate(binding_body(type), r_arg.get_new());
|
||||
type = instantiate(binding_body(type), r_arg.get_proof());
|
||||
}
|
||||
i++;
|
||||
});
|
||||
if (simplified) {
|
||||
lean_assert(is_eq(type));
|
||||
buffer<expr> type_args;
|
||||
get_app_args(type, type_args);
|
||||
expr & new_e = type_args[2];
|
||||
return join(result(new_e, proof), simplify_fun(new_e));
|
||||
} else {
|
||||
return simplify_fun(e);
|
||||
}
|
||||
}
|
||||
optional<result> r_args = synth_congr(e, [&](expr const & e) {
|
||||
return simplify(e, false);
|
||||
});
|
||||
if (r_args) return join(*r_args, simplify_fun(r_args->get_new()));
|
||||
}
|
||||
|
||||
/* (3) Fall back on generic binary congruence */
|
||||
|
@ -396,13 +403,16 @@ result simplifier::simplify_app(expr const & e) {
|
|||
expr const & f = app_fn(e);
|
||||
expr const & arg = app_arg(e);
|
||||
|
||||
result r_f = simplify(f);
|
||||
// TODO(dhs): it is not clear if this recursive call should be considered
|
||||
// a root or not, though does not matter since if + were being applied,
|
||||
// we would have synthesized a congruence rule in step (2).
|
||||
result r_f = simplify(f, false);
|
||||
|
||||
if (is_dependent_fn(f)) {
|
||||
if (r_f.is_none()) return e;
|
||||
else return congr_fun(r_f, arg);
|
||||
} else {
|
||||
result r_arg = simplify(arg);
|
||||
result r_arg = simplify(arg, false);
|
||||
if (r_f.is_none() && r_arg.is_none()) return e;
|
||||
else if (r_f.is_none()) return congr_arg(f, r_arg);
|
||||
else if (r_arg.is_none()) return congr_fun(r_f, arg);
|
||||
|
@ -417,9 +427,23 @@ result simplifier::simplify_fun(expr const & e) {
|
|||
lean_assert(is_app(e));
|
||||
buffer<expr> args;
|
||||
expr const & f = get_app_args(e, args);
|
||||
result r_f = simplify(f);
|
||||
result r_f = simplify(f, true);
|
||||
if (r_f.is_none()) return result(e);
|
||||
else return congr_funs(simplify(f), args);
|
||||
else return congr_funs(r_f, args);
|
||||
}
|
||||
|
||||
/* Proving */
|
||||
|
||||
optional<expr> simplifier::prove(expr const & thm) {
|
||||
flet<name> set_name(m_rel, get_iff_name());
|
||||
result r_cond = simplify(thm, true);
|
||||
if (is_constant(r_cond.get_new()) && const_name(r_cond.get_new()) == get_true_name()) {
|
||||
expr pf = m_app_builder.mk_app(get_iff_elim_right_name(),
|
||||
finalize(r_cond).get_proof(),
|
||||
mk_constant(get_true_intro_name()));
|
||||
return some_expr(pf);
|
||||
}
|
||||
return none_expr();
|
||||
}
|
||||
|
||||
/* Rewriting */
|
||||
|
@ -576,7 +600,7 @@ result simplifier::try_congr(expr const & e, congr_rule const & cr) {
|
|||
h_lhs = tmp_tctx->instantiate_uvars_mvars(h_lhs);
|
||||
lean_assert(!has_metavar(h_lhs));
|
||||
|
||||
result r_congr_hyp = simplify(h_lhs);
|
||||
result r_congr_hyp = simplify(h_lhs, true);
|
||||
expr hyp;
|
||||
if (r_congr_hyp.is_none()) {
|
||||
hyp = finalize(r_congr_hyp).get_proof();
|
||||
|
@ -633,11 +657,8 @@ bool simplifier::instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned
|
|||
if (tmp_tctx->is_mvar_assigned(i)) return;
|
||||
|
||||
if (tmp_tctx->is_prop(m_type)) {
|
||||
flet<name> set_name(m_rel, get_iff_name());
|
||||
result r_cond = simplify(m_type);
|
||||
if (is_constant(r_cond.get_new()) && const_name(r_cond.get_new()) == get_true_name()) {
|
||||
expr pf = m_app_builder.mk_app(name("iff", "elim_right"), finalize(r_cond).get_proof(), mk_constant(get_true_intro_name()));
|
||||
lean_verify(tmp_tctx->is_def_eq(m, pf));
|
||||
if (auto pf = prove(m_type)) {
|
||||
lean_verify(tmp_tctx->is_def_eq(m, *pf));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -653,6 +674,44 @@ bool simplifier::instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned
|
|||
return !failed;
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
optional<result> simplifier::synth_congr(expr const & e, F && simp) {
|
||||
static_assert(std::is_same<typename std::result_of<F(expr const & e)>::type, result>::value,
|
||||
"synth_congr: simp must take expressions to results");
|
||||
lean_assert(is_app(e));
|
||||
buffer<expr> args;
|
||||
expr f = get_app_args(e, args);
|
||||
auto congr_lemma = mk_congr_lemma_for_simp(f, args.size());
|
||||
if (!congr_lemma) return optional<result>();
|
||||
expr proof = congr_lemma->get_proof();
|
||||
expr type = congr_lemma->get_type();
|
||||
unsigned i = 0;
|
||||
bool simplified = false;
|
||||
buffer<expr> locals;
|
||||
for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) {
|
||||
proof = mk_app(proof, args[i]);
|
||||
type = instantiate(binding_body(type), args[i]);
|
||||
if (ckind == congr_arg_kind::Eq) {
|
||||
result r_arg = simp(args[i]);
|
||||
if (!r_arg.is_none()) simplified = true;
|
||||
r_arg = finalize(r_arg);
|
||||
proof = mk_app(proof, r_arg.get_new(), r_arg.get_proof());
|
||||
type = instantiate(binding_body(type), r_arg.get_new());
|
||||
type = instantiate(binding_body(type), r_arg.get_proof());
|
||||
}
|
||||
i++;
|
||||
});
|
||||
if (simplified) {
|
||||
lean_assert(is_eq(type));
|
||||
buffer<expr> type_args;
|
||||
get_app_args(type, type_args);
|
||||
expr & new_e = type_args[2];
|
||||
return optional<result>(result(new_e, proof));
|
||||
} else {
|
||||
return optional<result>(result(e));
|
||||
}
|
||||
}
|
||||
|
||||
/* Setup and teardown */
|
||||
|
||||
void initialize_simplifier() {
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
#include "util/name.h"
|
||||
namespace lean{
|
||||
name const * g_absurd = nullptr;
|
||||
name const * g_add = nullptr;
|
||||
name const * g_and = nullptr;
|
||||
name const * g_and_elim_left = nullptr;
|
||||
name const * g_and_elim_right = nullptr;
|
||||
|
@ -53,6 +54,8 @@ name const * g_iff_symm = nullptr;
|
|||
name const * g_iff_trans = nullptr;
|
||||
name const * g_iff_mp = nullptr;
|
||||
name const * g_iff_mpr = nullptr;
|
||||
name const * g_iff_elim_left = nullptr;
|
||||
name const * g_iff_elim_right = nullptr;
|
||||
name const * g_iff_false_intro = nullptr;
|
||||
name const * g_iff_true_intro = nullptr;
|
||||
name const * g_implies = nullptr;
|
||||
|
@ -63,10 +66,12 @@ name const * g_ite = nullptr;
|
|||
name const * g_lift = nullptr;
|
||||
name const * g_lift_down = nullptr;
|
||||
name const * g_lift_up = nullptr;
|
||||
name const * g_mul = nullptr;
|
||||
name const * g_nat = nullptr;
|
||||
name const * g_nat_of_num = nullptr;
|
||||
name const * g_nat_succ = nullptr;
|
||||
name const * g_nat_zero = nullptr;
|
||||
name const * g_neg = nullptr;
|
||||
name const * g_not = nullptr;
|
||||
name const * g_num = nullptr;
|
||||
name const * g_num_zero = nullptr;
|
||||
|
@ -168,6 +173,7 @@ name const * g_well_founded = nullptr;
|
|||
name const * g_zero = nullptr;
|
||||
void initialize_constants() {
|
||||
g_absurd = new name{"absurd"};
|
||||
g_add = new name{"add"};
|
||||
g_and = new name{"and"};
|
||||
g_and_elim_left = new name{"and", "elim_left"};
|
||||
g_and_elim_right = new name{"and", "elim_right"};
|
||||
|
@ -217,6 +223,8 @@ void initialize_constants() {
|
|||
g_iff_trans = new name{"iff", "trans"};
|
||||
g_iff_mp = new name{"iff", "mp"};
|
||||
g_iff_mpr = new name{"iff", "mpr"};
|
||||
g_iff_elim_left = new name{"iff", "elim_left"};
|
||||
g_iff_elim_right = new name{"iff", "elim_right"};
|
||||
g_iff_false_intro = new name{"iff_false_intro"};
|
||||
g_iff_true_intro = new name{"iff_true_intro"};
|
||||
g_implies = new name{"implies"};
|
||||
|
@ -227,10 +235,12 @@ void initialize_constants() {
|
|||
g_lift = new name{"lift"};
|
||||
g_lift_down = new name{"lift", "down"};
|
||||
g_lift_up = new name{"lift", "up"};
|
||||
g_mul = new name{"mul"};
|
||||
g_nat = new name{"nat"};
|
||||
g_nat_of_num = new name{"nat", "of_num"};
|
||||
g_nat_succ = new name{"nat", "succ"};
|
||||
g_nat_zero = new name{"nat", "zero"};
|
||||
g_neg = new name{"neg"};
|
||||
g_not = new name{"not"};
|
||||
g_num = new name{"num"};
|
||||
g_num_zero = new name{"num", "zero"};
|
||||
|
@ -333,6 +343,7 @@ void initialize_constants() {
|
|||
}
|
||||
void finalize_constants() {
|
||||
delete g_absurd;
|
||||
delete g_add;
|
||||
delete g_and;
|
||||
delete g_and_elim_left;
|
||||
delete g_and_elim_right;
|
||||
|
@ -382,6 +393,8 @@ void finalize_constants() {
|
|||
delete g_iff_trans;
|
||||
delete g_iff_mp;
|
||||
delete g_iff_mpr;
|
||||
delete g_iff_elim_left;
|
||||
delete g_iff_elim_right;
|
||||
delete g_iff_false_intro;
|
||||
delete g_iff_true_intro;
|
||||
delete g_implies;
|
||||
|
@ -392,10 +405,12 @@ void finalize_constants() {
|
|||
delete g_lift;
|
||||
delete g_lift_down;
|
||||
delete g_lift_up;
|
||||
delete g_mul;
|
||||
delete g_nat;
|
||||
delete g_nat_of_num;
|
||||
delete g_nat_succ;
|
||||
delete g_nat_zero;
|
||||
delete g_neg;
|
||||
delete g_not;
|
||||
delete g_num;
|
||||
delete g_num_zero;
|
||||
|
@ -497,6 +512,7 @@ void finalize_constants() {
|
|||
delete g_zero;
|
||||
}
|
||||
name const & get_absurd_name() { return *g_absurd; }
|
||||
name const & get_add_name() { return *g_add; }
|
||||
name const & get_and_name() { return *g_and; }
|
||||
name const & get_and_elim_left_name() { return *g_and_elim_left; }
|
||||
name const & get_and_elim_right_name() { return *g_and_elim_right; }
|
||||
|
@ -546,6 +562,8 @@ name const & get_iff_symm_name() { return *g_iff_symm; }
|
|||
name const & get_iff_trans_name() { return *g_iff_trans; }
|
||||
name const & get_iff_mp_name() { return *g_iff_mp; }
|
||||
name const & get_iff_mpr_name() { return *g_iff_mpr; }
|
||||
name const & get_iff_elim_left_name() { return *g_iff_elim_left; }
|
||||
name const & get_iff_elim_right_name() { return *g_iff_elim_right; }
|
||||
name const & get_iff_false_intro_name() { return *g_iff_false_intro; }
|
||||
name const & get_iff_true_intro_name() { return *g_iff_true_intro; }
|
||||
name const & get_implies_name() { return *g_implies; }
|
||||
|
@ -556,10 +574,12 @@ name const & get_ite_name() { return *g_ite; }
|
|||
name const & get_lift_name() { return *g_lift; }
|
||||
name const & get_lift_down_name() { return *g_lift_down; }
|
||||
name const & get_lift_up_name() { return *g_lift_up; }
|
||||
name const & get_mul_name() { return *g_mul; }
|
||||
name const & get_nat_name() { return *g_nat; }
|
||||
name const & get_nat_of_num_name() { return *g_nat_of_num; }
|
||||
name const & get_nat_succ_name() { return *g_nat_succ; }
|
||||
name const & get_nat_zero_name() { return *g_nat_zero; }
|
||||
name const & get_neg_name() { return *g_neg; }
|
||||
name const & get_not_name() { return *g_not; }
|
||||
name const & get_num_name() { return *g_num; }
|
||||
name const & get_num_zero_name() { return *g_num_zero; }
|
||||
|
|
|
@ -6,6 +6,7 @@ namespace lean {
|
|||
void initialize_constants();
|
||||
void finalize_constants();
|
||||
name const & get_absurd_name();
|
||||
name const & get_add_name();
|
||||
name const & get_and_name();
|
||||
name const & get_and_elim_left_name();
|
||||
name const & get_and_elim_right_name();
|
||||
|
@ -55,6 +56,8 @@ name const & get_iff_symm_name();
|
|||
name const & get_iff_trans_name();
|
||||
name const & get_iff_mp_name();
|
||||
name const & get_iff_mpr_name();
|
||||
name const & get_iff_elim_left_name();
|
||||
name const & get_iff_elim_right_name();
|
||||
name const & get_iff_false_intro_name();
|
||||
name const & get_iff_true_intro_name();
|
||||
name const & get_implies_name();
|
||||
|
@ -65,10 +68,12 @@ name const & get_ite_name();
|
|||
name const & get_lift_name();
|
||||
name const & get_lift_down_name();
|
||||
name const & get_lift_up_name();
|
||||
name const & get_mul_name();
|
||||
name const & get_nat_name();
|
||||
name const & get_nat_of_num_name();
|
||||
name const & get_nat_succ_name();
|
||||
name const & get_nat_zero_name();
|
||||
name const & get_neg_name();
|
||||
name const & get_not_name();
|
||||
name const & get_num_name();
|
||||
name const & get_num_zero_name();
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
absurd
|
||||
add
|
||||
and
|
||||
and.elim_left
|
||||
and.elim_right
|
||||
|
@ -48,6 +49,8 @@ iff.symm
|
|||
iff.trans
|
||||
iff.mp
|
||||
iff.mpr
|
||||
iff.elim_left
|
||||
iff.elim_right
|
||||
iff_false_intro
|
||||
iff_true_intro
|
||||
implies
|
||||
|
@ -58,10 +61,12 @@ ite
|
|||
lift
|
||||
lift.down
|
||||
lift.up
|
||||
mul
|
||||
nat
|
||||
nat.of_num
|
||||
nat.succ
|
||||
nat.zero
|
||||
neg
|
||||
not
|
||||
num
|
||||
num.zero
|
||||
|
|
Loading…
Reference in a new issue