diff --git a/src/library/blast/simplifier.cpp b/src/library/blast/simplifier.cpp index 821bf1069..7329257dc 100644 --- a/src/library/blast/simplifier.cpp +++ b/src/library/blast/simplifier.cpp @@ -21,6 +21,7 @@ Author: Daniel Selsam #include "util/pair.h" #include "util/sexpr/option_declarations.h" #include +#include #ifndef LEAN_DEFAULT_SIMPLIFY_MAX_STEPS #define LEAN_DEFAULT_SIMPLIFY_MAX_STEPS 1000 @@ -95,6 +96,25 @@ bool get_simplify_fuse() { return ios().get_options().get_bool(*g_simplify_fuse, LEAN_DEFAULT_SIMPLIFY_FUSE); } +/* Miscellaneous helpers */ + +static bool is_const_app(expr const & e, name const & n, unsigned nargs) { + expr const & f = get_app_fn(e); + return is_constant(f) && const_name(f) == n && get_app_num_args(e) == nargs; +} + +static bool is_add_app(expr const & e) { + return is_const_app(e, get_add_name(), 4); +} + +static bool is_mul_app(expr const & e) { + return is_const_app(e, get_mul_name(), 4); +} + +static bool is_neg_app(expr const & e) { + return is_const_app(e, get_neg_name(), 3); +} + /* Main simplifier class */ class simplifier { @@ -166,6 +186,10 @@ class simplifier { return srss; } + + bool instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned num_emeta, + list const & emetas, list const & instances); + /* Results */ result lift_from_eq(expr const & x, result const & r); result join(result const & r1, result const & r2); @@ -173,12 +197,15 @@ class simplifier { result finalize(result const & r); /* Simplification */ - result simplify(expr const & e); + result simplify(expr const & e, bool is_root); result simplify_lambda(expr const & e); result simplify_pi(expr const & e); result simplify_app(expr const & e); result simplify_fun(expr const & e); + /* Proving */ + optional prove(expr const & thm); + /* Rewriting */ result rewrite(expr const & e); result rewrite(expr const & e, simp_rule_sets const & srss); @@ -193,12 +220,20 @@ class simplifier { result try_congrs(expr const & e); result try_congr(expr const & e, congr_rule const & cr); - bool instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned num_emeta, - list const & emetas, list const & instances); + template + optional synth_congr(expr const & e, F && simp); + + /* Fusion */ + std::array tc_mask1{{true}}; + std::array tc_mask2{{true, false}}; + result maybe_fuse(expr const & e, bool is_root); + result fuse(expr const & e); + expr_pair split_summand(expr const & e, expr const & f_mul, expr const & one); + public: simplifier(name const & rel, simp_rule_sets const & srss); - result operator()(expr const & e) { return simplify(e); } + result operator()(expr const & e) { return simplify(e, true); } }; /* Constructor */ @@ -257,7 +292,7 @@ result simplifier::finalize(result const & r) { /* Simplification */ -result simplifier::simplify(expr const & e) { +result simplifier::simplify(expr const & e, bool is_root) { m_num_steps++; flet inc_depth(m_depth, m_depth+1); @@ -289,7 +324,7 @@ result simplifier::simplify(expr const & e) { lean_unreachable(); case expr_kind::Macro: if (m_expand_macros) { - if (auto m = m_tmp_tctx->expand_macro(e)) r = join(r, simplify(whnf(*m))); + if (auto m = m_tmp_tctx->expand_macro(e)) r = join(r, simplify(whnf(*m), is_root)); } break; case expr_kind::Lambda: @@ -308,12 +343,12 @@ result simplifier::simplify(expr const & e) { if (r.get_new() == e && !using_eq()) { { flet use_eq(m_rel, get_eq_name()); - r = simplify(r.get_new()); + r = simplify(r.get_new(), is_root); } if (!r.is_none()) r = lift_from_eq(e, r); } - if (m_exhaustive && r.get_new() != e) r = join(r, simplify(r.get_new())); + if (m_exhaustive && r.get_new() != e) r = join(r, simplify(r.get_new(), is_root)); if (m_memoize) cache_save(e, r); @@ -332,7 +367,7 @@ result simplifier::simplify_lambda(expr const & _e) { e = instantiate(binding_body(e), l); } - result r = simplify(e); + result r = simplify(e, false); if (r.is_none()) { return result(_e); } for (int i = ls.size() - 1; i >= 0; --i) r = funext(r, ls[i]); @@ -357,38 +392,10 @@ result simplifier::simplify_app(expr const & e) { /* (2) Synthesize congruence lemma */ if (using_eq()) { - buffer args; - expr fn = get_app_args(e, args); - if (auto congr_lemma = mk_congr_lemma_for_simp(fn, args.size())) { - expr proof = congr_lemma->get_proof(); - expr type = congr_lemma->get_type(); - unsigned i = 0; - bool simplified = false; - buffer locals; - for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) { - proof = mk_app(proof, args[i]); - type = instantiate(binding_body(type), args[i]); - - if (ckind == congr_arg_kind::Eq) { - result r_arg = simplify(args[i]); - if (!r_arg.is_none()) simplified = true; - r_arg = finalize(r_arg); - proof = mk_app(proof, r_arg.get_new(), r_arg.get_proof()); - type = instantiate(binding_body(type), r_arg.get_new()); - type = instantiate(binding_body(type), r_arg.get_proof()); - } - i++; - }); - if (simplified) { - lean_assert(is_eq(type)); - buffer type_args; - get_app_args(type, type_args); - expr & new_e = type_args[2]; - return join(result(new_e, proof), simplify_fun(new_e)); - } else { - return simplify_fun(e); - } - } + optional r_args = synth_congr(e, [&](expr const & e) { + return simplify(e, false); + }); + if (r_args) return join(*r_args, simplify_fun(r_args->get_new())); } /* (3) Fall back on generic binary congruence */ @@ -396,13 +403,16 @@ result simplifier::simplify_app(expr const & e) { expr const & f = app_fn(e); expr const & arg = app_arg(e); - result r_f = simplify(f); + // TODO(dhs): it is not clear if this recursive call should be considered + // a root or not, though does not matter since if + were being applied, + // we would have synthesized a congruence rule in step (2). + result r_f = simplify(f, false); if (is_dependent_fn(f)) { if (r_f.is_none()) return e; else return congr_fun(r_f, arg); } else { - result r_arg = simplify(arg); + result r_arg = simplify(arg, false); if (r_f.is_none() && r_arg.is_none()) return e; else if (r_f.is_none()) return congr_arg(f, r_arg); else if (r_arg.is_none()) return congr_fun(r_f, arg); @@ -417,9 +427,23 @@ result simplifier::simplify_fun(expr const & e) { lean_assert(is_app(e)); buffer args; expr const & f = get_app_args(e, args); - result r_f = simplify(f); + result r_f = simplify(f, true); if (r_f.is_none()) return result(e); - else return congr_funs(simplify(f), args); + else return congr_funs(r_f, args); +} + +/* Proving */ + +optional simplifier::prove(expr const & thm) { + flet set_name(m_rel, get_iff_name()); + result r_cond = simplify(thm, true); + if (is_constant(r_cond.get_new()) && const_name(r_cond.get_new()) == get_true_name()) { + expr pf = m_app_builder.mk_app(get_iff_elim_right_name(), + finalize(r_cond).get_proof(), + mk_constant(get_true_intro_name())); + return some_expr(pf); + } + return none_expr(); } /* Rewriting */ @@ -576,7 +600,7 @@ result simplifier::try_congr(expr const & e, congr_rule const & cr) { h_lhs = tmp_tctx->instantiate_uvars_mvars(h_lhs); lean_assert(!has_metavar(h_lhs)); - result r_congr_hyp = simplify(h_lhs); + result r_congr_hyp = simplify(h_lhs, true); expr hyp; if (r_congr_hyp.is_none()) { hyp = finalize(r_congr_hyp).get_proof(); @@ -633,11 +657,8 @@ bool simplifier::instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned if (tmp_tctx->is_mvar_assigned(i)) return; if (tmp_tctx->is_prop(m_type)) { - flet set_name(m_rel, get_iff_name()); - result r_cond = simplify(m_type); - if (is_constant(r_cond.get_new()) && const_name(r_cond.get_new()) == get_true_name()) { - expr pf = m_app_builder.mk_app(name("iff", "elim_right"), finalize(r_cond).get_proof(), mk_constant(get_true_intro_name())); - lean_verify(tmp_tctx->is_def_eq(m, pf)); + if (auto pf = prove(m_type)) { + lean_verify(tmp_tctx->is_def_eq(m, *pf)); return; } } @@ -653,6 +674,44 @@ bool simplifier::instantiate_emetas(blast_tmp_type_context & tmp_tctx, unsigned return !failed; } +template +optional simplifier::synth_congr(expr const & e, F && simp) { + static_assert(std::is_same::type, result>::value, + "synth_congr: simp must take expressions to results"); + lean_assert(is_app(e)); + buffer args; + expr f = get_app_args(e, args); + auto congr_lemma = mk_congr_lemma_for_simp(f, args.size()); + if (!congr_lemma) return optional(); + expr proof = congr_lemma->get_proof(); + expr type = congr_lemma->get_type(); + unsigned i = 0; + bool simplified = false; + buffer locals; + for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) { + proof = mk_app(proof, args[i]); + type = instantiate(binding_body(type), args[i]); + if (ckind == congr_arg_kind::Eq) { + result r_arg = simp(args[i]); + if (!r_arg.is_none()) simplified = true; + r_arg = finalize(r_arg); + proof = mk_app(proof, r_arg.get_new(), r_arg.get_proof()); + type = instantiate(binding_body(type), r_arg.get_new()); + type = instantiate(binding_body(type), r_arg.get_proof()); + } + i++; + }); + if (simplified) { + lean_assert(is_eq(type)); + buffer type_args; + get_app_args(type, type_args); + expr & new_e = type_args[2]; + return optional(result(new_e, proof)); + } else { + return optional(result(e)); + } +} + /* Setup and teardown */ void initialize_simplifier() { diff --git a/src/library/constants.cpp b/src/library/constants.cpp index 74b358285..b03cee53d 100644 --- a/src/library/constants.cpp +++ b/src/library/constants.cpp @@ -4,6 +4,7 @@ #include "util/name.h" namespace lean{ name const * g_absurd = nullptr; +name const * g_add = nullptr; name const * g_and = nullptr; name const * g_and_elim_left = nullptr; name const * g_and_elim_right = nullptr; @@ -53,6 +54,8 @@ name const * g_iff_symm = nullptr; name const * g_iff_trans = nullptr; name const * g_iff_mp = nullptr; name const * g_iff_mpr = nullptr; +name const * g_iff_elim_left = nullptr; +name const * g_iff_elim_right = nullptr; name const * g_iff_false_intro = nullptr; name const * g_iff_true_intro = nullptr; name const * g_implies = nullptr; @@ -63,10 +66,12 @@ name const * g_ite = nullptr; name const * g_lift = nullptr; name const * g_lift_down = nullptr; name const * g_lift_up = nullptr; +name const * g_mul = nullptr; name const * g_nat = nullptr; name const * g_nat_of_num = nullptr; name const * g_nat_succ = nullptr; name const * g_nat_zero = nullptr; +name const * g_neg = nullptr; name const * g_not = nullptr; name const * g_num = nullptr; name const * g_num_zero = nullptr; @@ -168,6 +173,7 @@ name const * g_well_founded = nullptr; name const * g_zero = nullptr; void initialize_constants() { g_absurd = new name{"absurd"}; + g_add = new name{"add"}; g_and = new name{"and"}; g_and_elim_left = new name{"and", "elim_left"}; g_and_elim_right = new name{"and", "elim_right"}; @@ -217,6 +223,8 @@ void initialize_constants() { g_iff_trans = new name{"iff", "trans"}; g_iff_mp = new name{"iff", "mp"}; g_iff_mpr = new name{"iff", "mpr"}; + g_iff_elim_left = new name{"iff", "elim_left"}; + g_iff_elim_right = new name{"iff", "elim_right"}; g_iff_false_intro = new name{"iff_false_intro"}; g_iff_true_intro = new name{"iff_true_intro"}; g_implies = new name{"implies"}; @@ -227,10 +235,12 @@ void initialize_constants() { g_lift = new name{"lift"}; g_lift_down = new name{"lift", "down"}; g_lift_up = new name{"lift", "up"}; + g_mul = new name{"mul"}; g_nat = new name{"nat"}; g_nat_of_num = new name{"nat", "of_num"}; g_nat_succ = new name{"nat", "succ"}; g_nat_zero = new name{"nat", "zero"}; + g_neg = new name{"neg"}; g_not = new name{"not"}; g_num = new name{"num"}; g_num_zero = new name{"num", "zero"}; @@ -333,6 +343,7 @@ void initialize_constants() { } void finalize_constants() { delete g_absurd; + delete g_add; delete g_and; delete g_and_elim_left; delete g_and_elim_right; @@ -382,6 +393,8 @@ void finalize_constants() { delete g_iff_trans; delete g_iff_mp; delete g_iff_mpr; + delete g_iff_elim_left; + delete g_iff_elim_right; delete g_iff_false_intro; delete g_iff_true_intro; delete g_implies; @@ -392,10 +405,12 @@ void finalize_constants() { delete g_lift; delete g_lift_down; delete g_lift_up; + delete g_mul; delete g_nat; delete g_nat_of_num; delete g_nat_succ; delete g_nat_zero; + delete g_neg; delete g_not; delete g_num; delete g_num_zero; @@ -497,6 +512,7 @@ void finalize_constants() { delete g_zero; } name const & get_absurd_name() { return *g_absurd; } +name const & get_add_name() { return *g_add; } name const & get_and_name() { return *g_and; } name const & get_and_elim_left_name() { return *g_and_elim_left; } name const & get_and_elim_right_name() { return *g_and_elim_right; } @@ -546,6 +562,8 @@ name const & get_iff_symm_name() { return *g_iff_symm; } name const & get_iff_trans_name() { return *g_iff_trans; } name const & get_iff_mp_name() { return *g_iff_mp; } name const & get_iff_mpr_name() { return *g_iff_mpr; } +name const & get_iff_elim_left_name() { return *g_iff_elim_left; } +name const & get_iff_elim_right_name() { return *g_iff_elim_right; } name const & get_iff_false_intro_name() { return *g_iff_false_intro; } name const & get_iff_true_intro_name() { return *g_iff_true_intro; } name const & get_implies_name() { return *g_implies; } @@ -556,10 +574,12 @@ name const & get_ite_name() { return *g_ite; } name const & get_lift_name() { return *g_lift; } name const & get_lift_down_name() { return *g_lift_down; } name const & get_lift_up_name() { return *g_lift_up; } +name const & get_mul_name() { return *g_mul; } name const & get_nat_name() { return *g_nat; } name const & get_nat_of_num_name() { return *g_nat_of_num; } name const & get_nat_succ_name() { return *g_nat_succ; } name const & get_nat_zero_name() { return *g_nat_zero; } +name const & get_neg_name() { return *g_neg; } name const & get_not_name() { return *g_not; } name const & get_num_name() { return *g_num; } name const & get_num_zero_name() { return *g_num_zero; } diff --git a/src/library/constants.h b/src/library/constants.h index 12d9a7647..a838ca55c 100644 --- a/src/library/constants.h +++ b/src/library/constants.h @@ -6,6 +6,7 @@ namespace lean { void initialize_constants(); void finalize_constants(); name const & get_absurd_name(); +name const & get_add_name(); name const & get_and_name(); name const & get_and_elim_left_name(); name const & get_and_elim_right_name(); @@ -55,6 +56,8 @@ name const & get_iff_symm_name(); name const & get_iff_trans_name(); name const & get_iff_mp_name(); name const & get_iff_mpr_name(); +name const & get_iff_elim_left_name(); +name const & get_iff_elim_right_name(); name const & get_iff_false_intro_name(); name const & get_iff_true_intro_name(); name const & get_implies_name(); @@ -65,10 +68,12 @@ name const & get_ite_name(); name const & get_lift_name(); name const & get_lift_down_name(); name const & get_lift_up_name(); +name const & get_mul_name(); name const & get_nat_name(); name const & get_nat_of_num_name(); name const & get_nat_succ_name(); name const & get_nat_zero_name(); +name const & get_neg_name(); name const & get_not_name(); name const & get_num_name(); name const & get_num_zero_name(); diff --git a/src/library/constants.txt b/src/library/constants.txt index 1569f6e92..05c6a027d 100644 --- a/src/library/constants.txt +++ b/src/library/constants.txt @@ -1,4 +1,5 @@ absurd +add and and.elim_left and.elim_right @@ -48,6 +49,8 @@ iff.symm iff.trans iff.mp iff.mpr +iff.elim_left +iff.elim_right iff_false_intro iff_true_intro implies @@ -58,10 +61,12 @@ ite lift lift.down lift.up +mul nat nat.of_num nat.succ nat.zero +neg not num num.zero