From 2e626b29fbb077d31db5b44c96239f2c8bffc2c8 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 6 Feb 2015 11:03:36 -0800 Subject: [PATCH] feat(library/tactic/rewrite_tactic): allow many constants to be provided in a single rewrite unfold step --- src/frontends/lean/parse_rewrite_tactic.cpp | 16 +++++++++-- src/library/tactic/rewrite_tactic.cpp | 32 +++++++++++---------- src/library/tactic/rewrite_tactic.h | 2 +- tests/lean/run/rewriter12.lean | 9 ++++++ 4 files changed, 41 insertions(+), 18 deletions(-) create mode 100644 tests/lean/run/rewriter12.lean diff --git a/src/frontends/lean/parse_rewrite_tactic.cpp b/src/frontends/lean/parse_rewrite_tactic.cpp index 88cd87152..b29ef8851 100644 --- a/src/frontends/lean/parse_rewrite_tactic.cpp +++ b/src/frontends/lean/parse_rewrite_tactic.cpp @@ -35,9 +35,21 @@ static expr parse_rule(parser & p) { expr parse_rewrite_element(parser & p) { if (p.curr_is_token(get_up_tk()) || p.curr_is_token(get_caret_tk())) { p.next(); - name n = p.check_constant_next("invalid unfold rewrite step, constant expected"); + buffer to_unfold; + if (p.curr_is_token(get_lcurly_tk())) { + p.next(); + while (true) { + to_unfold.push_back(p.check_constant_next("invalid unfold rewrite step, identifier expected")); + if (!p.curr_is_token(get_comma_tk())) + break; + p.next(); + } + p.check_token_next(get_rcurly_tk(), "invalid unfold rewrite step, ',' or '}' expected"); + } else { + to_unfold.push_back(p.check_constant_next("invalid unfold rewrite step, identifier or '{' expected")); + } location loc = parse_tactic_location(p); - return mk_rewrite_unfold(n, loc); + return mk_rewrite_unfold(to_list(to_unfold), loc); } bool symm = false; if (p.curr_is_token(get_sub_tk())) { diff --git a/src/library/tactic/rewrite_tactic.cpp b/src/library/tactic/rewrite_tactic.cpp index 9dcffde40..60af42075 100644 --- a/src/library/tactic/rewrite_tactic.cpp +++ b/src/library/tactic/rewrite_tactic.cpp @@ -42,19 +42,21 @@ unsigned get_rewriter_max_iterations(options const & opts) { } class unfold_info { - name m_name; - location m_location; + list m_names; + location m_location; public: unfold_info() {} - unfold_info(name const & n, location const & loc):m_name(n), m_location(loc) {} - name const & get_name() const { return m_name; } + unfold_info(list const & l, location const & loc):m_names(l), m_location(loc) {} + list const & get_names() const { return m_names; } location const & get_location() const { return m_location; } friend serializer & operator<<(serializer & s, unfold_info const & e) { - s << e.m_name << e.m_location; + write_list(s, e.m_names); + s << e.m_location; return s; } friend deserializer & operator>>(deserializer & d, unfold_info & e) { - d >> e.m_name >> e.m_location; + e.m_names = read_list(d); + d >> e.m_location; return d; } }; @@ -204,8 +206,8 @@ public: unfold_info const & get_info() const { return m_info; } }; -expr mk_rewrite_unfold(name const & n, location const & loc) { - macro_definition def(new rewrite_unfold_macro_cell(unfold_info(n, loc))); +expr mk_rewrite_unfold(list const & ns, location const & loc) { + macro_definition def(new rewrite_unfold_macro_cell(unfold_info(ns, loc))); return mk_macro(def); } @@ -435,11 +437,11 @@ class rewrite_fn { return m_g.mk_meta(m_ngen.next(), type); } - optional reduce(expr const & e, optional const & to_unfold) { + optional reduce(expr const & e, list const & to_unfold) { bool unfolded = !to_unfold; extra_opaque_pred pred([&](name const & n) { // everything is opaque but to_unfold - if (to_unfold && *to_unfold == n) { + if (std::find(to_unfold.begin(), to_unfold.end(), n) != to_unfold.end()) { unfolded = true; return false; } else { @@ -467,7 +469,7 @@ class rewrite_fn { update_goal(new_g); } - bool process_reduce_goal(optional const & to_unfold) { + bool process_reduce_goal(list const & to_unfold) { if (auto new_type = reduce(m_g.get_type(), to_unfold)) { replace_goal(*new_type); return true; @@ -496,7 +498,7 @@ class rewrite_fn { update_goal(new_g); } - bool process_reduce_hypothesis(expr const & hyp, optional const & to_unfold) { + bool process_reduce_hypothesis(expr const & hyp, list const & to_unfold) { if (auto new_hyp_type = reduce(mlocal_type(hyp), to_unfold)) { replace_hypothesis(hyp, *new_hyp_type); return true; @@ -505,7 +507,7 @@ class rewrite_fn { } } - bool process_reduce_step(optional const & to_unfold, location const & loc) { + bool process_reduce_step(list const & to_unfold, location const & loc) { if (loc.is_goal_only()) return process_reduce_goal(to_unfold); bool progress = false; @@ -527,7 +529,7 @@ class rewrite_fn { bool process_unfold_step(expr const & elem) { lean_assert(is_rewrite_unfold_step(elem)); auto info = get_rewrite_unfold_info(elem); - return process_reduce_step(optional(info.get_name()), info.get_location()); + return process_reduce_step(info.get_names(), info.get_location()); } optional unify_with(expr const & t, expr const & e) { @@ -594,7 +596,7 @@ class rewrite_fn { lean_assert(is_rewrite_reduce_step(elem)); if (macro_num_args(elem) == 0) { auto info = get_rewrite_reduce_info(elem); - return process_reduce_step(optional(), info.get_location()); + return process_reduce_step(list(), info.get_location()); } else { auto info = get_rewrite_reduce_info(elem); return process_reduce_to_step(macro_arg(elem, 0), info.get_location()); diff --git a/src/library/tactic/rewrite_tactic.h b/src/library/tactic/rewrite_tactic.h index 1014f738d..08bcfbb10 100644 --- a/src/library/tactic/rewrite_tactic.h +++ b/src/library/tactic/rewrite_tactic.h @@ -9,7 +9,7 @@ Author: Leonardo de Moura #include "library/tactic/location.h" namespace lean { -expr mk_rewrite_unfold(name const & n, location const & loc); +expr mk_rewrite_unfold(list const & ns, location const & loc); expr mk_rewrite_reduce(location const & loc); expr mk_rewrite_reduce_to(expr const & e, location const & loc); expr mk_rewrite_once(optional const & pattern, expr const & H, bool symm, location const & loc); diff --git a/tests/lean/run/rewriter12.lean b/tests/lean/run/rewriter12.lean new file mode 100644 index 000000000..fb4955b56 --- /dev/null +++ b/tests/lean/run/rewriter12.lean @@ -0,0 +1,9 @@ +import data.nat +open nat +constant f : nat → nat + +example (x y : nat) (H1 : (λ z, z + 0) x = y) : f x = f y := +by rewrite [▸* at H1, ^{add, nat.rec_on, of_num} at H1, H1] + +example (x y : nat) (H1 : (λ z, z + 0) x = y) : f x = f y := +by rewrite [▸* at H1, ↑{add, nat.rec_on, of_num} at H1, H1]