From 14c72e82f6d3054b6581286cf90ff7d941072907 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 4 Feb 2015 20:04:19 -0800 Subject: [PATCH] feat(library/tactic/rewrite_tactic): add support for rewriting hypotheses --- src/library/tactic/rewrite_tactic.cpp | 129 ++++++++++++++++++-------- tests/lean/run/rewriter3.lean | 23 +++++ 2 files changed, 113 insertions(+), 39 deletions(-) create mode 100644 tests/lean/run/rewriter3.lean diff --git a/src/library/tactic/rewrite_tactic.cpp b/src/library/tactic/rewrite_tactic.cpp index 9cc035fb1..8a39c65c4 100644 --- a/src/library/tactic/rewrite_tactic.cpp +++ b/src/library/tactic/rewrite_tactic.cpp @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "util/list_fn.h" #include "util/sexpr/option_declarations.h" #include "kernel/instantiate.h" +#include "kernel/abstract.h" #include "kernel/replace_fn.h" #include "kernel/for_each_fn.h" #include "kernel/inductive/inductive.h" @@ -518,8 +519,10 @@ class rewrite_fn { // rule, new_t typedef optional> unify_result; - // When successful, the result is the pair (H, new_t) where (H : new_t = t) - unify_result unify_target(expr const & t, expr const & pre_elem) { + // When successful, the result is the pair (H, new_t) where + // (H : new_t = t) if is_goal == true + // (H : t = new_t) if is_goal == false + unify_result unify_target(expr const & t, expr const & pre_elem, bool is_goal) { try { expr rule = get_rewrite_rule(pre_elem); auto rcs = m_elab(m_g, m_ngen.mk_child(), rule, false); @@ -565,25 +568,37 @@ class rewrite_fn { m_subst = new_subst; expr lhs = app_arg(app_fn(rule_type)); expr rhs = app_arg(rule_type); - if (symm) { - return unify_result(rule, lhs); + if (is_goal) { + if (symm) { + return unify_result(rule, lhs); + } else { + rule = mk_symm(*m_tc, rule); + return unify_result(rule, rhs); + } } else { - rule = mk_symm(*m_tc, rule); - return unify_result(rule, rhs); + if (symm) { + rule = mk_symm(*m_tc, rule); + return unify_result(rule, lhs); + } else { + return unify_result(rule, rhs); + } } } } catch (exception&) {} return unify_result(); } - // target, new_target, H : represents the rewrite H : target = new_target + // target, new_target, H : represents the rewrite (H : target = new_target) for hypothesis and (H : new_target = target) for goals typedef optional> find_result; // Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule // in the rewrite step \c pre_elem with \c t. // When successful, this method returns the target \c t, the fully elaborated rule \c r, // and the new value \c new_t (i.e., the expression that will replace \c t). - find_result find_target(expr const & e, expr const & pattern, expr const & pre_elem) { + // + // \remark is_goal == true if \c e is the type of a goal. Otherwise, it is assumed to be the type + // of a hypothesis. This flag affects the equality proof built by this method. + find_result find_target(expr const & e, expr const & pattern, expr const & pre_elem, bool is_goal) { find_result result; for_each(e, [&](expr const & t, unsigned) { if (result) @@ -595,7 +610,7 @@ class rewrite_fn { if (assigned) reset_subst(); if (r) { - if (auto p = unify_target(t, pre_elem)) { + if (auto p = unify_target(t, pre_elem, is_goal)) { result = std::make_tuple(t, p->second, p->first); return false; } @@ -606,45 +621,81 @@ class rewrite_fn { return result; } - /** Given (a, b, P[a], Heq : b = a, occ), return (P[b], M : P[b], H : P[a]) - where M is a metavariable application of a fresh metavariable, and H is a witness (based on M) for P[a]. + bool process_rewrite_hypothesis(expr const & hyp, expr const & pre_elem, expr const & pattern, occurrence const & occ) { + expr Pa = mlocal_type(hyp); + bool is_goal = false; + if (auto it = find_target(Pa, pattern, pre_elem, is_goal)) { + expr a, Heq, b; // Heq is a proof of a = b + std::tie(a, b, Heq) = *it; - \remark occ is used to select which occurrences of a in P[a] will be replaced with b + bool has_dep_elim = inductive::has_dep_elim(m_env, get_eq_name()); + unsigned vidx = has_dep_elim ? 1 : 0; + expr Px = replace_occurrences(Pa, a, occ, vidx); + expr Pb = instantiate(Px, vidx, b); - \remark the witness \c H is used using eq.rec - */ - std::tuple apply_rewrite(expr const & a, expr const & b, expr const & Pa, expr const & Heq, occurrence const & occ) { - bool has_dep_elim = inductive::has_dep_elim(m_env, get_eq_name()); - unsigned vidx = has_dep_elim ? 1 : 0; - expr Px = replace_occurrences(Pa, a, occ, vidx); - expr Pb = instantiate(Px, vidx, b); - expr A = m_tc->infer(a).first; - level l1 = sort_level(m_tc->ensure_type(Pa).first); - level l2 = sort_level(m_tc->ensure_type(A).first); - expr M = m_g.mk_meta(m_ngen.next(), Pb); - expr H; - if (has_dep_elim) { - expr Haeqx = mk_app(mk_constant(get_eq_name(), {l1}), A, b, mk_var(0)); - expr P = mk_lambda("x", A, mk_lambda("H", Haeqx, Px)); - H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, P, M, a, Heq}); - } else { - H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, mk_lambda("x", A, Px), M, a, Heq}); + expr A = m_tc->infer(a).first; + level l1 = sort_level(m_tc->ensure_type(Pa).first); + level l2 = sort_level(m_tc->ensure_type(A).first); + expr H; + if (has_dep_elim) { + expr Haeqx = mk_app(mk_constant(get_eq_name(), {l1}), A, b, mk_var(0)); + expr P = mk_lambda("x", A, mk_lambda("H", Haeqx, Px)); + H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, a, P, hyp, b, Heq}); + } else { + H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, a, mk_lambda("x", A, Px), hyp, b, Heq}); + } + + expr new_hyp = update_mlocal(hyp, Pb); + buffer new_hyps; + buffer args; + m_g.get_hyps(new_hyps); + for (expr & h : new_hyps) { + if (mlocal_name(h) == mlocal_name(hyp)) { + h = new_hyp; + args.push_back(H); + break; + } else { + args.push_back(h); + } + } + expr new_type = m_g.get_type(); + expr new_mvar = mk_metavar(m_ngen.next(), Pi(new_hyps, new_type)); + expr new_meta = mk_app(new_mvar, new_hyps); + goal new_g(new_meta, new_type); + expr val = m_g.abstract(mk_app(new_mvar, args)); + m_subst.assign(m_g.get_name(), val); + update_goal(new_g); + return true; } - return std::make_tuple(Pb, M, H); - } - - bool process_rewrite_hypothesis(expr const & /*hyp*/, expr const & /*pre_elem*/, expr const & /*pattern*/, occurrence const & /*occ*/) { - // TODO(Leo) return false; } bool process_rewrite_goal(expr const & pre_elem, expr const & pattern, occurrence const & occ) { - expr Pa = m_g.get_type(); - if (auto it = find_target(Pa, pattern, pre_elem)) { + expr Pa = m_g.get_type(); + bool is_goal = true; + if (auto it = find_target(Pa, pattern, pre_elem, is_goal)) { expr a, Heq, b; std::tie(a, b, Heq) = *it; - expr Pb, M, H; - std::tie(Pb, M, H) = apply_rewrite(a, b, Pa, Heq, occ); + + // Given (a, b, P[a], Heq : b = a, occ), return (P[b], M : P[b], H : P[a]) + // where M is a metavariable application of a fresh metavariable, and H is a witness (based on M) for P[a]. + bool has_dep_elim = inductive::has_dep_elim(m_env, get_eq_name()); + unsigned vidx = has_dep_elim ? 1 : 0; + expr Px = replace_occurrences(Pa, a, occ, vidx); + expr Pb = instantiate(Px, vidx, b); + expr A = m_tc->infer(a).first; + level l1 = sort_level(m_tc->ensure_type(Pa).first); + level l2 = sort_level(m_tc->ensure_type(A).first); + expr M = m_g.mk_meta(m_ngen.next(), Pb); + expr H; + if (has_dep_elim) { + expr Haeqx = mk_app(mk_constant(get_eq_name(), {l1}), A, b, mk_var(0)); + expr P = mk_lambda("x", A, mk_lambda("H", Haeqx, Px)); + H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, P, M, a, Heq}); + } else { + H = mk_app({mk_constant(get_eq_rec_name(), {l1, l2}), A, b, mk_lambda("x", A, Px), M, a, Heq}); + } + goal new_g(M, Pb); expr val = m_g.abstract(H); m_subst.assign(m_g.get_name(), val); diff --git a/tests/lean/run/rewriter3.lean b/tests/lean/run/rewriter3.lean new file mode 100644 index 000000000..1b253d8c0 --- /dev/null +++ b/tests/lean/run/rewriter3.lean @@ -0,0 +1,23 @@ +import data.nat +open algebra + +constant f {A : Type} : A → A → A + +theorem test1 {A : Type} [s : comm_ring A] (a b c : A) (H : a + 0 = 0) : f a a = f 0 0 := +begin + rewrite add_zero at H, + rewrite H +end + +theorem test2 {A : Type} [s : comm_ring A] (a b c : A) (H : a + 0 = 0) : f a a = f 0 0 := +begin + rewrite add_zero at *, + rewrite H +end + +theorem test3 {A : Type} [s : comm_ring A] (a b c : A) (H : a + 0 = 0 + 0) : f a a = f 0 0 := +begin + rewrite add_zero at H, + rewrite zero_add at H, + rewrite H +end