diff --git a/src/kernel/replace.h b/src/kernel/replace.h index 1925c570a..1641f91a1 100644 --- a/src/kernel/replace.h +++ b/src/kernel/replace.h @@ -9,6 +9,15 @@ Author: Leonardo de Moura #include "expr.h" #include "expr_maps.h" namespace lean { +/** + \brief Default replace_fn postprocessor functional object. It is a + do-nothing object. +*/ +class default_replace_postprocessor { +public: + void operator()(expr const & old_e, expr const & new_e) {} +}; + /** \brief Functional for applying F to the subexpressions of a given expression. @@ -18,13 +27,17 @@ namespace lean { F is invoked for each subexpression \c s of the input expression e. In a call F(n, s), n is the scope level, i.e., the number of bindings operators that enclosing \c s. + + P is a "post-processing" functional object that is applied to each + pair (old, new) */ -template +template class replace_fn { static_assert(std::is_same::type, expr>::value, "replace_fn: return type of F is not expr"); expr_cell_offset_map m_cache; F m_f; + P m_post; expr apply(expr const & e, unsigned offset) { bool sh = false; @@ -60,12 +73,15 @@ class replace_fn { if (sh) m_cache.insert(std::make_pair(expr_cell_offset(e.raw(), offset), r)); + m_post(e, r); + return r; } public: - replace_fn(F const & f): - m_f(f) { + replace_fn(F const & f, P const & p = P()): + m_f(f), + m_post(p) { } expr operator()(expr const & e) { diff --git a/src/tests/kernel/replace.cpp b/src/tests/kernel/replace.cpp index 8dd1ad0f9..6389c6d39 100644 --- a/src/tests/kernel/replace.cpp +++ b/src/tests/kernel/replace.cpp @@ -8,6 +8,8 @@ Author: Leonardo de Moura #include "abstract.h" #include "instantiate.h" #include "deep_copy.h" +#include "expr_maps.h" +#include "replace.h" #include "printer.h" #include "name.h" #include "test.h" @@ -44,9 +46,63 @@ static void tst2() { lean_assert(instantiate(mk_pi("_", Var(3), Var(4)), Var(0)) == mk_pi("_", Var(2), Var(3))); } +class tracer { + expr_map & m_trace; +public: + tracer(expr_map & trace):m_trace(trace) {} + + void operator()(expr const & old_e, expr const & new_e) { + if (!is_eqp(new_e, old_e)) { + m_trace[new_e] = old_e; + } + } +}; + +static void tst3() { + expr f = Const("f"); + expr x = Const("x"); + expr y = Const("y"); + expr c = Const("c"); + expr d = Const("d"); + expr A = Const("A"); + expr_map trace; + auto proc = [&](expr const & x, unsigned offset) -> expr { + if (is_var(x)) { + unsigned vidx = var_idx(x); + if (vidx == offset) + return c; + else if (vidx > offset) + return mk_var(vidx-1); + else + return x; + } else { + return x; + } + }; + replace_fn replacer(proc, tracer(trace)); + expr t = Fun({{x, A}, {y, A}}, f(x, f(f(f(x,x), f(y, d)), f(d, d)))); + expr b = abst_body(t); + expr r = replacer(b); + std::cout << r << "\n"; + lean_assert(r == Fun({y, A}, f(c, f(f(f(c,c), f(y, d)), f(d, d))))); + for (auto p : trace) { + std::cout << p.first << " --> " << p.second << "\n"; + } + lean_assert(trace[c] == Var(1)); + std::cout << arg(arg(abst_body(r), 2), 2) << "\n"; + lean_assert(arg(arg(abst_body(r), 2), 2) == f(d,d)); + lean_assert(trace.find(arg(arg(abst_body(r), 2), 2)) == trace.end()); + lean_assert(trace.find(abst_body(r)) != trace.end()); + lean_assert(trace.find(arg(abst_body(r), 2)) != trace.end()); + lean_assert(trace.find(arg(arg(abst_body(r), 2), 1)) != trace.end()); + lean_assert(trace.find(arg(arg(arg(abst_body(r), 2), 1), 1)) != trace.end()); + lean_assert(trace.find(arg(arg(arg(abst_body(r), 2), 1), 2)) == trace.end()); +} + int main() { tst1(); tst2(); + tst3(); std::cout << "done" << "\n"; return has_violations() ? 1 : 0; }