From 7a4eb4b8ed701e8c80a203fb93e772511c283cf3 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 24 Jan 2014 22:32:55 -0800 Subject: [PATCH] feat(library/simplifier): contextual simplification for A -> B Signed-off-by: Leonardo de Moura --- src/library/simplifier/simplifier.cpp | 69 +++++++++++++++++++++++---- tests/lean/simp24.lean | 36 ++++++++++++++ tests/lean/simp24.lean.expected.out | 12 +++++ 3 files changed, 108 insertions(+), 9 deletions(-) create mode 100644 tests/lean/simp24.lean create mode 100644 tests/lean/simp24.lean.expected.out diff --git a/src/library/simplifier/simplifier.cpp b/src/library/simplifier/simplifier.cpp index 83efd039f..8e29a6c0d 100644 --- a/src/library/simplifier/simplifier.cpp +++ b/src/library/simplifier/simplifier.cpp @@ -100,6 +100,7 @@ unsigned get_simplifier_max_steps(options const & opts) { return opts.get_unsign static name g_local("local"); static name g_C("C"); +static name g_H("H"); static name g_x("x"); static name g_unique = name::mk_internal_unique_name(); @@ -108,6 +109,7 @@ class simplifier_fn { expr m_out; // the result of a simplification step optional m_proof; // a proof that the result is equal to the input (when m_proofs_enabled) bool m_heq_proof; // true if the proof is for heterogeneous equality + result() {} explicit result(expr const & out, bool heq_proof = false): m_out(out), m_heq_proof(heq_proof) {} result(expr const & out, expr const & pr, bool heq_proof = false): @@ -920,22 +922,71 @@ class simplifier_fn { result simplify_pi(expr const & e) { lean_assert(is_pi(e)); - // TODO(Leo): handle implication, i.e., e is_proposition and is_arrow - if (m_has_heq) { + expr const & d = abst_domain(e); + expr b = abst_body(e); + bool is_prop = is_proposition(e); + bool is_d_prop = is_proposition(d); + bool is_arr = is_arrow(e); + if (is_d_prop && is_arr) { + if (m_contextual) { + // Contextual simplification for A -> B + // Rewrite A to A' + // And rewrite B to B' using A' + result res_d = simplify(d); + ensure_homogeneous(d, res_d); + flet set_depth(m_contextual_depth, m_contextual_depth+1); + expr H_proof = mk_constant(name(g_unique, m_contextual_depth)); + result res_b; + { + updt_rule_set update(m_rule_sets[0], res_d.m_out, H_proof); + freset m_reset_cache(m_cache); // must reset cache for the recursive call because we updated the rule_sets + set_context set(*this, extend(m_ctx, abst_name(e), res_d.m_out)); + res_b = simplify(b); + } + ensure_homogeneous(b, res_b); + if (is_eqp(res_d.m_out, d) && is_eqp(res_b.m_out, b)) + return rewrite(e, result(e)); + expr out = update_pi(e, res_d.m_out, res_b.m_out); + if (!m_proofs_enabled) + return rewrite(e, result(out)); + name C_name(g_C, m_contextual_depth); // H_name is a cryptic unique name + expr proof = mk_imp_congr_th(d, lower_free_vars(b, 1, 1), + res_d.m_out, lower_free_vars(res_b.m_out, 1, 1), + get_proof(res_d), mk_lambda(C_name, res_d.m_out, abstract(get_proof(res_b), H_proof))); + return rewrite(e, result(out, proof)); + } else { + // Simplify A -> B (when m_contextual == false) + result res_d = simplify(d); + ensure_homogeneous(d, res_d); + set_context set(*this, extend(m_ctx, abst_name(e), res_d.m_out)); + result res_b = simplify(b); + ensure_homogeneous(b, res_b); + if (is_eqp(res_d.m_out, d) && is_eqp(res_b.m_out, b)) + return rewrite(e, result(e)); + expr out = update_pi(e, res_d.m_out, res_b.m_out); + if (!m_proofs_enabled) + return rewrite(e, result(out)); + expr proof = mk_imp_congr_th(d, lower_free_vars(b, 1, 1), + res_d.m_out, lower_free_vars(res_b.m_out, 1, 1), + get_proof(res_d), mk_lambda(g_H, res_d.m_out, get_proof(res_b))); + return rewrite(e, result(out, proof)); + } + } else if (m_has_heq) { // TODO(Leo) return result(e); - } else if (is_proposition(e)) { - set_context set(*this, extend(m_ctx, abst_name(e), abst_domain(e))); - result res_body = simplify(abst_body(e)); + } else if (is_prop) { + // Simplify (forall x : A, P x) + set_context set(*this, extend(m_ctx, abst_name(e), d)); + result res_body = simplify(b); lean_assert(!res_body.m_heq_proof); expr new_body = res_body.m_out; - if (is_eqp(new_body, abst_body(e))) + if (is_eqp(new_body, b)) return rewrite(e, result(e)); - expr out = mk_pi(abst_name(e), abst_domain(e), new_body); + expr out = mk_pi(abst_name(e), d, new_body); if (!m_proofs_enabled || !res_body.m_proof) return rewrite(e, result(out)); - expr pr = mk_allext_th(abst_domain(e), - mk_lambda(e, abst_body(e)), + expr pr = mk_allext_th(d, + mk_lambda(e, b), mk_lambda(e, abst_body(out)), mk_lambda(e, *res_body.m_proof)); return rewrite(e, result(out, pr)); diff --git a/tests/lean/simp24.lean b/tests/lean/simp24.lean new file mode 100644 index 000000000..d74dbbd6a --- /dev/null +++ b/tests/lean/simp24.lean @@ -0,0 +1,36 @@ +rewrite_set simple +add_rewrite eq_id imp_truel imp_truer Nat::add_zeror : simple +variables a b : Nat +(* +local opts = options({"simplifier", "contextual"}, false) +local t = parse_lean('λ x, a = a → x = a') +local t2, pr = simplify(t, "simple", get_environment(), context(), opts) +print(t2) +print(pr) +get_environment():type_check(pr) +*) + +(* +local opts = options({"simplifier", "contextual"}, false) +local t = parse_lean('λ x, x = a → x = x') +local t2, pr = simplify(t, "simple", get_environment(), context(), opts) +print(t2) +print(pr) +get_environment():type_check(pr) +*) + +(* +local opts = options({"simplifier", "contextual"}, false) +local t = parse_lean('λ x, x = a + 0 → a = a') +local t2, pr = simplify(t, "simple", get_environment(), context(), opts) +print(t2) +print(pr) +*) + +(* +local t = parse_lean('λ x, a + 0 = 1 → x > a') +local t2, pr = simplify(t, "simple") +print(t2) +print(pr) +get_environment():type_check(pr) +*) diff --git a/tests/lean/simp24.lean.expected.out b/tests/lean/simp24.lean.expected.out new file mode 100644 index 000000000..80c1e3bdd --- /dev/null +++ b/tests/lean/simp24.lean.expected.out @@ -0,0 +1,12 @@ + Set: pp::colors + Set: pp::unicode + Assumed: a + Assumed: b +λ x : ℕ, x = a +funext (λ x : ℕ, trans (imp_congr (eq_id a) (λ H : ⊤, refl (x = a))) (imp_truel (x = a))) +λ x : ℕ, ⊤ +funext (λ x : ℕ, trans (imp_congr (refl (x = a)) (λ H : x = a, eq_id x)) (imp_truer (x = a))) +λ x : ℕ, ⊤ +funext (λ x : ℕ, trans (imp_congr (congr2 (eq x) (Nat::add_zeror a)) (λ H : x = a, eq_id a)) (imp_truer (x = a))) +λ x : ℕ, a = 1 → x > 1 +funext (λ x : ℕ, imp_congr (congr1 1 (congr2 eq (Nat::add_zeror a))) (λ C::1 : a = 1, congr2 (Nat::gt x) C::1))