diff --git a/src/kernel/instantiate.cpp b/src/kernel/instantiate.cpp index cb0a7ac8b..023dfac20 100644 --- a/src/kernel/instantiate.cpp +++ b/src/kernel/instantiate.cpp @@ -122,7 +122,7 @@ expr beta_reduce(expr t) { return none_expr(); }; while (true) { - expr new_t = replace_fn(f)(t); + expr new_t = replace(t, f); if (new_t == t) return new_t; else diff --git a/src/kernel/replace_fn.cpp b/src/kernel/replace_fn.cpp index 64815f26d..e898b0025 100644 --- a/src/kernel/replace_fn.cpp +++ b/src/kernel/replace_fn.cpp @@ -1,129 +1,206 @@ /* -Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Copyright (c) 2013-2014 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include #include "kernel/replace_fn.h" +#ifndef LEAN_DEFAULT_REPLACE_CACHE_CAPACITY +#define LEAN_DEFAULT_REPLACE_CACHE_CAPACITY 1024*32 +#endif + namespace lean { -void replace_fn::save_result(expr const & e, expr const & r, unsigned offset, bool shared) { - if (shared) - m_cache.insert(std::make_pair(expr_cell_offset(e.raw(), offset), r)); - m_post(e, r); - m_rs.push_back(r); -} +struct replace_cache { + struct entry { + expr_cell * m_cell; + unsigned m_offset; + expr m_result; + entry():m_cell(nullptr) {} + }; + unsigned m_capacity; + std::vector m_cache; + std::vector m_used; + replace_cache(unsigned c):m_capacity(c), m_cache(c) {} -/** - \brief Visit \c e at the given offset. Return true iff the result is on the - result stack \c m_rs. Return false iff a new frame was pushed on the stack \c m_fs. - The idea is that after the frame is processed, the result will be on the result stack. -*/ -bool replace_fn::visit(expr const & e, unsigned offset) { - bool shared = false; - if (is_shared(e)) { - expr_cell_offset p(e.raw(), offset); - auto it = m_cache.find(p); - if (it != m_cache.end()) { - m_rs.push_back(it->second); - return true; + expr * find(expr const & e, unsigned offset) { + unsigned i = hash(e.hash_alloc(), offset) % m_capacity; + if (m_cache[i].m_cell == e.raw()) + return &m_cache[i].m_result; + else + return nullptr; + } + + void insert(expr const & e, unsigned offset, expr const & v) { + unsigned i = hash(e.hash_alloc(), offset) % m_capacity; + if (m_cache[i].m_cell == nullptr) + m_used.push_back(i); + m_cache[i].m_cell = e.raw(); + m_cache[i].m_offset = offset; + m_cache[i].m_result = v; + } + + void clear() { + for (unsigned i : m_used) { + m_cache[i].m_cell = nullptr; + m_cache[i].m_result = expr(); } - shared = true; + m_used.clear(); } +}; - optional r = m_f(e, offset); - if (r) { - save_result(e, *r, offset, shared); - return true; - } else if (is_atomic(e)) { - save_result(e, e, offset, shared); - return true; - } else { - m_fs.emplace_back(e, offset, shared); - return false; - } -} +MK_THREAD_LOCAL_GET(replace_cache, get_replace_cache, LEAN_DEFAULT_REPLACE_CACHE_CAPACITY) + +struct replace_cache_reset { + ~replace_cache_reset() { get_replace_cache().clear(); } +}; /** - \brief Return true iff f.m_index == idx. - When the result is true, f.m_index is incremented. + \brief Functional for applying F to the subexpressions of a given expression. + + The signature of \c F is + expr const &, unsigned -> optional(expr) + + F is invoked for each subexpression \c s of the input expression e. + In a call F(s, n), n is the scope level, i.e., the number of + bindings operators that enclosing \c s. The replaces only visits children of \c e + if F return none_expr */ -bool replace_fn::check_index(frame & f, unsigned idx) { - if (f.m_index == idx) { - f.m_index++; - return true; - } else { - return false; +class replace_fn { + struct frame { + expr m_expr; + unsigned m_offset; + bool m_shared; + unsigned m_index; + frame(expr const & e, unsigned o, bool s):m_expr(e), m_offset(o), m_shared(s), m_index(0) {} + }; + typedef buffer frame_stack; + typedef buffer result_stack; + + std::function(expr const &, unsigned)> m_f; + frame_stack m_fs; + result_stack m_rs; + replace_cache & m_cache; + + void save_result(expr const & e, expr const & r, unsigned offset, bool shared) { + if (shared) + m_cache.insert(e, offset, r); + m_rs.push_back(r); } -} -expr const & replace_fn::rs(int i) { - lean_assert(i < 0); - return m_rs[m_rs.size() + i]; -} - -void replace_fn::pop_rs(unsigned num) { - m_rs.shrink(m_rs.size() - num); -} - -expr replace_fn::operator()(expr const & e) { - expr r; - visit(e, 0); - while (!m_fs.empty()) { - begin_loop: - check_interrupted(); - frame & f = m_fs.back(); - expr const & e = f.m_expr; - unsigned offset = f.m_offset; - switch (e.kind()) { - case expr_kind::Constant: case expr_kind::Sort: - case expr_kind::Var: - lean_unreachable(); // LCOV_EXCL_LINE - case expr_kind::Meta: case expr_kind::Local: - if (check_index(f, 0) && !visit(mlocal_type(e), offset)) - goto begin_loop; - r = update_mlocal(e, rs(-1)); - pop_rs(1); - break; - case expr_kind::App: - if (check_index(f, 0) && !visit(app_fn(e), offset)) - goto begin_loop; - if (check_index(f, 1) && !visit(app_arg(e), offset)) - goto begin_loop; - r = update_app(e, rs(-2), rs(-1)); - pop_rs(2); - break; - case expr_kind::Pi: case expr_kind::Lambda: - if (check_index(f, 0) && !visit(binding_domain(e), offset)) - goto begin_loop; - if (check_index(f, 1) && !visit(binding_body(e), offset + 1)) - goto begin_loop; - r = update_binding(e, rs(-2), rs(-1)); - pop_rs(2); - break; - case expr_kind::Macro: - while (f.m_index < macro_num_args(e)) { - unsigned idx = f.m_index; - f.m_index++; - if (!visit(macro_arg(e, idx), offset)) - goto begin_loop; + /** + \brief Visit \c e at the given offset. Return true iff the result is on the + result stack \c m_rs. Return false iff a new frame was pushed on the stack \c m_fs. + The idea is that after the frame is processed, the result will be on the result stack. + */ + bool visit(expr const & e, unsigned offset) { + bool shared = false; + if (is_shared(e)) { + if (auto r = m_cache.find(e, offset)) { + m_rs.push_back(*r); + return true; } - r = update_macro(e, macro_num_args(e), &rs(-macro_num_args(e))); - pop_rs(macro_num_args(e)); - break; + shared = true; } - save_result(e, r, offset, f.m_shared); - m_fs.pop_back(); - } - lean_assert(m_rs.size() == 1); - r = m_rs.back(); - m_rs.pop_back(); - return r; -} -void replace_fn::clear() { - m_cache.clear(); - m_fs.clear(); - m_rs.clear(); + optional r = m_f(e, offset); + if (r) { + save_result(e, *r, offset, shared); + return true; + } else if (is_atomic(e)) { + save_result(e, e, offset, shared); + return true; + } else { + m_fs.emplace_back(e, offset, shared); + return false; + } + } + + /** + \brief Return true iff f.m_index == idx. + When the result is true, f.m_index is incremented. + */ + bool check_index(frame & f, unsigned idx) { + if (f.m_index == idx) { + f.m_index++; + return true; + } else { + return false; + } + } + + expr const & rs(int i) { + lean_assert(i < 0); + return m_rs[m_rs.size() + i]; + } + + void pop_rs(unsigned num) { + m_rs.shrink(m_rs.size() - num); + } + +public: + template + replace_fn(F const & f):m_f(f), m_cache(get_replace_cache()) {} + + expr operator()(expr const & e) { + replace_cache_reset reset; + expr r; + visit(e, 0); + while (!m_fs.empty()) { + begin_loop: + check_interrupted(); + frame & f = m_fs.back(); + expr const & e = f.m_expr; + unsigned offset = f.m_offset; + switch (e.kind()) { + case expr_kind::Constant: case expr_kind::Sort: + case expr_kind::Var: + lean_unreachable(); // LCOV_EXCL_LINE + case expr_kind::Meta: case expr_kind::Local: + if (check_index(f, 0) && !visit(mlocal_type(e), offset)) + goto begin_loop; + r = update_mlocal(e, rs(-1)); + pop_rs(1); + break; + case expr_kind::App: + if (check_index(f, 0) && !visit(app_fn(e), offset)) + goto begin_loop; + if (check_index(f, 1) && !visit(app_arg(e), offset)) + goto begin_loop; + r = update_app(e, rs(-2), rs(-1)); + pop_rs(2); + break; + case expr_kind::Pi: case expr_kind::Lambda: + if (check_index(f, 0) && !visit(binding_domain(e), offset)) + goto begin_loop; + if (check_index(f, 1) && !visit(binding_body(e), offset + 1)) + goto begin_loop; + r = update_binding(e, rs(-2), rs(-1)); + pop_rs(2); + break; + case expr_kind::Macro: + while (f.m_index < macro_num_args(e)) { + unsigned idx = f.m_index; + f.m_index++; + if (!visit(macro_arg(e, idx), offset)) + goto begin_loop; + } + r = update_macro(e, macro_num_args(e), &rs(-macro_num_args(e))); + pop_rs(macro_num_args(e)); + break; + } + save_result(e, r, offset, f.m_shared); + m_fs.pop_back(); + } + lean_assert(m_rs.size() == 1); + r = m_rs.back(); + m_rs.pop_back(); + return r; + } +}; + +expr replace(expr const & e, std::function(expr const &, unsigned)> const & f) { + return replace_fn(f)(e); } } diff --git a/src/kernel/replace_fn.h b/src/kernel/replace_fn.h index bb6a2fd60..073485c7f 100644 --- a/src/kernel/replace_fn.h +++ b/src/kernel/replace_fn.h @@ -1,5 +1,5 @@ /* -Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Copyright (c) 2013-2014 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura @@ -13,64 +13,12 @@ Author: Leonardo de Moura namespace lean { /** - \brief Default replace_fn postprocessor functional object. It is a - do-nothing object. -*/ -class default_replace_postprocessor { -public: - void operator()(expr const &, expr const &) {} -}; + \brief Apply f to the subexpressions of a given expression. -/** - \brief Functional for applying F to the subexpressions of a given expression. - - The signature of \c F is - expr const &, unsigned -> optional(expr) - - F is invoked for each subexpression \c s of the input expression e. - In a call F(s, n), n is the scope level, i.e., the number of + f is invoked for each subexpression \c s of the input expression e. + In a call f(s, n), n is the scope level, i.e., the number of bindings operators that enclosing \c s. The replaces only visits children of \c e - if F return none_expr - - P is a "post-processing" functional object that is applied to each - pair (old, new) + if f return none_expr. */ -class replace_fn { - struct frame { - expr m_expr; - unsigned m_offset; - bool m_shared; - unsigned m_index; - frame(expr const & e, unsigned o, bool s):m_expr(e), m_offset(o), m_shared(s), m_index(0) {} - }; - typedef buffer frame_stack; - typedef buffer result_stack; - - expr_cell_offset_map m_cache; - std::function(expr const &, unsigned)> m_f; - std::function m_post; - frame_stack m_fs; - result_stack m_rs; - - void save_result(expr const & e, expr const & r, unsigned offset, bool shared); - bool visit(expr const & e, unsigned offset); - bool check_index(frame & f, unsigned idx); - expr const & rs(int i); - void pop_rs(unsigned num); - -public: - template - replace_fn(F const & f, P const & p = P()): - m_f(f), m_post(p) {} - expr operator()(expr const & e); - void clear(); -}; - -template expr replace(expr const & e, F const & f) { - return replace_fn(f)(e); -} - -template expr replace(expr const & e, F const & f, P const & p) { - return replace_fn(f, p)(e); -} +expr replace(expr const & e, std::function(expr const &, unsigned)> const & f); } diff --git a/src/tests/kernel/replace.cpp b/src/tests/kernel/replace.cpp index eb119a98f..78b11ca76 100644 --- a/src/tests/kernel/replace.cpp +++ b/src/tests/kernel/replace.cpp @@ -56,63 +56,10 @@ public: } }; -static expr arg(expr n, unsigned i) { - buffer args; - while (is_app(n)) { - args.push_back(app_arg(n)); - n = app_fn(n); - } - args.push_back(n); - return args[args.size() - i - 1]; -} - -static void tst3() { - expr f = Const("f"); - expr c = Const("c"); - expr d = Const("d"); - expr A = Const("A"); - expr_map trace; - auto proc = [&](expr const & x, unsigned offset) -> optional { - if (is_var(x)) { - unsigned vidx = var_idx(x); - if (vidx == offset) - return some_expr(c); - else if (vidx > offset) - return some_expr(mk_var(vidx-1)); - else - return none_expr(); - } else { - return none_expr(); - } - }; - expr x = Local("x", A); - expr y = Local("y", A); - - replace_fn replacer(proc, tracer(trace)); - expr t = Fun({x, y}, f(x, f(f(f(x, x), f(y, d)), f(d, d)))); - expr b = binding_body(t); - expr r = replacer(b); - std::cout << r << "\n"; - lean_assert(r == Fun(y, f(c, f(f(f(c, c), f(y, d)), f(d, d))))); - for (auto p : trace) { - std::cout << p.first << " --> " << p.second << "\n"; - } - lean_assert(trace[c] == Var(1)); - std::cout << arg(arg(binding_body(r), 2), 2) << "\n"; - lean_assert(arg(arg(binding_body(r), 2), 2) == f(d, d)); - lean_assert(trace.find(arg(arg(binding_body(r), 2), 2)) == trace.end()); - lean_assert(trace.find(binding_body(r)) != trace.end()); - lean_assert(trace.find(arg(binding_body(r), 2)) != trace.end()); - lean_assert(trace.find(arg(arg(binding_body(r), 2), 1)) != trace.end()); - lean_assert(trace.find(arg(arg(arg(binding_body(r), 2), 1), 1)) != trace.end()); - lean_assert(trace.find(arg(arg(arg(binding_body(r), 2), 1), 2)) == trace.end()); -} - int main() { save_stack_info(); tst1(); tst2(); - tst3(); std::cout << "done" << "\n"; return has_violations() ? 1 : 0; }