feat(library/blast/simplifier): add eta-reduction to simplifier

This commit is contained in:
Leonardo de Moura 2015-12-06 20:41:25 -08:00
parent 7eb1525ba5
commit 1eb28b842e
2 changed files with 45 additions and 4 deletions

View file

@ -8,6 +8,7 @@ Author: Daniel Selsam
#include "kernel/expr_maps.h"
#include "kernel/instantiate.h"
#include "library/constants.h"
#include "library/normalize.h"
#include "library/expr_lt.h"
#include "library/class_instance_resolution.h"
#include "library/relation_manager.h"
@ -255,6 +256,11 @@ class simplifier {
result fuse(expr const & e);
expr_pair split_summand(expr const & e, expr const & f_mul, expr const & one);
/* Apply whnf and eta-reduction
\remark We want (Sum n (fun x, f x)) and (Sum n f) to be the same.
\remark We may want to switch to eta-expansion later (see paper: "The Virtues of Eta-expansion").
TODO(Daniel, Leo): should we add an option for disabling/enabling eta? */
expr whnf_eta(expr const & e);
public:
simplifier(name const & rel, expr_predicate const & simp_pred): m_rel(rel), m_simp_pred(simp_pred) { }
@ -362,6 +368,11 @@ result simplifier::finalize(result const & r) {
return result(r.get_new(), pf);
}
/* Whnf + Eta */
expr simplifier::whnf_eta(expr const & e) {
return try_eta(whnf(e));
}
/* Simplification */
result simplifier::simplify(expr const & e, simp_rule_sets const & srss) {
@ -396,9 +407,9 @@ result simplifier::simplify(expr const & e, bool is_root) {
result r(e);
if (m_top_down) r = join(r, rewrite(whnf(r.get_new())));
if (m_top_down) r = join(r, rewrite(whnf_eta(r.get_new())));
r.update(whnf(r.get_new()));
r.update(whnf_eta(r.get_new()));
switch (r.get_new().kind()) {
case expr_kind::Local:
@ -411,7 +422,7 @@ result simplifier::simplify(expr const & e, bool is_root) {
lean_unreachable();
case expr_kind::Macro:
if (m_expand_macros) {
if (auto m = m_tmp_tctx->expand_macro(e)) r = join(r, simplify(whnf(*m), is_root));
if (auto m = m_tmp_tctx->expand_macro(e)) r = join(r, simplify(whnf_eta(*m), is_root));
}
break;
case expr_kind::Lambda:
@ -425,7 +436,7 @@ result simplifier::simplify(expr const & e, bool is_root) {
break;
}
if (!m_top_down) r = join(r, rewrite(whnf(r.get_new())));
if (!m_top_down) r = join(r, rewrite(whnf_eta(r.get_new())));
if (r.get_new() == e && !using_eq()) {
result r_eq;

View file

@ -0,0 +1,30 @@
import data.nat
open - [rrs] nat
definition Sum : nat → (nat → nat) → nat :=
sorry
notation `Σ` binders ` < ` n `, ` r:(scoped f, Sum n f) := r
lemma Sum_const [simp] (n : nat) (c : nat) : (Σ x < n, c) = n * c :=
sorry
lemma Sum_add [simp] (f g : nat → nat) (n : nat) : (Σ x < n, f x + g x) = (Σ x < n, f x) + (Σ x < n, g x) :=
sorry
attribute add.assoc add.comm add.left_comm mul_one add_zero zero_add one_mul mul.comm mul.assoc mul.left_comm [simp]
example (f : nat → nat) (n : nat) : (Σ x < n, f x + 1) = (Σ x < n, f x) + n :=
by simp
example (f g h : nat → nat) (n : nat) : (Σ x < n, f x + g x + h x) = (Σ x < n, h x) + (Σ x < n, f x) + (Σ x < n, g x) :=
by simp
example (f g h : nat → nat) (n : nat) : (Σ x < n, f x + g x + h x) = Sum n h + (Σ x < n, f x) + (Σ x < n, g x) :=
by simp
example (f g h : nat → nat) (n : nat) : (Σ x < n, f x + g x + h x + 0) = Sum n h + (Σ x < n, f x) + (Σ x < n, g x) :=
by simp
example (f g h : nat → nat) (n : nat) : (Σ x < n, f x + g x + h x + 2) = 0 + Sum n h + (Σ x < n, f x) + (Σ x < n, g x) + 2 * n :=
by simp