refactor(kernel/replace_fn): use thread local cache

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-07-15 05:34:45 +01:00
parent bd0cc5c365
commit 999782d89d
4 changed files with 191 additions and 219 deletions

View file

@ -122,7 +122,7 @@ expr beta_reduce(expr t) {
return none_expr(); return none_expr();
}; };
while (true) { while (true) {
expr new_t = replace_fn(f)(t); expr new_t = replace(t, f);
if (new_t == t) if (new_t == t)
return new_t; return new_t;
else else

View file

@ -1,16 +1,91 @@
/* /*
Copyright (c) 2013 Microsoft Corporation. All rights reserved. Copyright (c) 2013-2014 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE. Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura Author: Leonardo de Moura
*/ */
#include <vector>
#include "kernel/replace_fn.h" #include "kernel/replace_fn.h"
#ifndef LEAN_DEFAULT_REPLACE_CACHE_CAPACITY
#define LEAN_DEFAULT_REPLACE_CACHE_CAPACITY 1024*32
#endif
namespace lean { namespace lean {
void replace_fn::save_result(expr const & e, expr const & r, unsigned offset, bool shared) { struct replace_cache {
struct entry {
expr_cell * m_cell;
unsigned m_offset;
expr m_result;
entry():m_cell(nullptr) {}
};
unsigned m_capacity;
std::vector<entry> m_cache;
std::vector<unsigned> m_used;
replace_cache(unsigned c):m_capacity(c), m_cache(c) {}
expr * find(expr const & e, unsigned offset) {
unsigned i = hash(e.hash_alloc(), offset) % m_capacity;
if (m_cache[i].m_cell == e.raw())
return &m_cache[i].m_result;
else
return nullptr;
}
void insert(expr const & e, unsigned offset, expr const & v) {
unsigned i = hash(e.hash_alloc(), offset) % m_capacity;
if (m_cache[i].m_cell == nullptr)
m_used.push_back(i);
m_cache[i].m_cell = e.raw();
m_cache[i].m_offset = offset;
m_cache[i].m_result = v;
}
void clear() {
for (unsigned i : m_used) {
m_cache[i].m_cell = nullptr;
m_cache[i].m_result = expr();
}
m_used.clear();
}
};
MK_THREAD_LOCAL_GET(replace_cache, get_replace_cache, LEAN_DEFAULT_REPLACE_CACHE_CAPACITY)
struct replace_cache_reset {
~replace_cache_reset() { get_replace_cache().clear(); }
};
/**
\brief Functional for applying <tt>F</tt> to the subexpressions of a given expression.
The signature of \c F is
expr const &, unsigned -> optional(expr)
F is invoked for each subexpression \c s of the input expression e.
In a call <tt>F(s, n)</tt>, n is the scope level, i.e., the number of
bindings operators that enclosing \c s. The replaces only visits children of \c e
if F return none_expr
*/
class replace_fn {
struct frame {
expr m_expr;
unsigned m_offset;
bool m_shared;
unsigned m_index;
frame(expr const & e, unsigned o, bool s):m_expr(e), m_offset(o), m_shared(s), m_index(0) {}
};
typedef buffer<frame> frame_stack;
typedef buffer<expr> result_stack;
std::function<optional<expr>(expr const &, unsigned)> m_f;
frame_stack m_fs;
result_stack m_rs;
replace_cache & m_cache;
void save_result(expr const & e, expr const & r, unsigned offset, bool shared) {
if (shared) if (shared)
m_cache.insert(std::make_pair(expr_cell_offset(e.raw(), offset), r)); m_cache.insert(e, offset, r);
m_post(e, r);
m_rs.push_back(r); m_rs.push_back(r);
} }
@ -19,13 +94,11 @@ void replace_fn::save_result(expr const & e, expr const & r, unsigned offset, bo
result stack \c m_rs. Return false iff a new frame was pushed on the stack \c m_fs. result stack \c m_rs. Return false iff a new frame was pushed on the stack \c m_fs.
The idea is that after the frame is processed, the result will be on the result stack. The idea is that after the frame is processed, the result will be on the result stack.
*/ */
bool replace_fn::visit(expr const & e, unsigned offset) { bool visit(expr const & e, unsigned offset) {
bool shared = false; bool shared = false;
if (is_shared(e)) { if (is_shared(e)) {
expr_cell_offset p(e.raw(), offset); if (auto r = m_cache.find(e, offset)) {
auto it = m_cache.find(p); m_rs.push_back(*r);
if (it != m_cache.end()) {
m_rs.push_back(it->second);
return true; return true;
} }
shared = true; shared = true;
@ -48,7 +121,7 @@ bool replace_fn::visit(expr const & e, unsigned offset) {
\brief Return true iff <tt>f.m_index == idx</tt>. \brief Return true iff <tt>f.m_index == idx</tt>.
When the result is true, <tt>f.m_index</tt> is incremented. When the result is true, <tt>f.m_index</tt> is incremented.
*/ */
bool replace_fn::check_index(frame & f, unsigned idx) { bool check_index(frame & f, unsigned idx) {
if (f.m_index == idx) { if (f.m_index == idx) {
f.m_index++; f.m_index++;
return true; return true;
@ -57,16 +130,21 @@ bool replace_fn::check_index(frame & f, unsigned idx) {
} }
} }
expr const & replace_fn::rs(int i) { expr const & rs(int i) {
lean_assert(i < 0); lean_assert(i < 0);
return m_rs[m_rs.size() + i]; return m_rs[m_rs.size() + i];
} }
void replace_fn::pop_rs(unsigned num) { void pop_rs(unsigned num) {
m_rs.shrink(m_rs.size() - num); m_rs.shrink(m_rs.size() - num);
} }
expr replace_fn::operator()(expr const & e) { public:
template<typename F>
replace_fn(F const & f):m_f(f), m_cache(get_replace_cache()) {}
expr operator()(expr const & e) {
replace_cache_reset reset;
expr r; expr r;
visit(e, 0); visit(e, 0);
while (!m_fs.empty()) { while (!m_fs.empty()) {
@ -120,10 +198,9 @@ expr replace_fn::operator()(expr const & e) {
m_rs.pop_back(); m_rs.pop_back();
return r; return r;
} }
};
void replace_fn::clear() { expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned)> const & f) {
m_cache.clear(); return replace_fn(f)(e);
m_fs.clear();
m_rs.clear();
} }
} }

View file

@ -1,5 +1,5 @@
/* /*
Copyright (c) 2013 Microsoft Corporation. All rights reserved. Copyright (c) 2013-2014 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE. Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura Author: Leonardo de Moura
@ -13,64 +13,12 @@ Author: Leonardo de Moura
namespace lean { namespace lean {
/** /**
\brief Default replace_fn postprocessor functional object. It is a \brief Apply <tt>f</tt> to the subexpressions of a given expression.
do-nothing object.
*/
class default_replace_postprocessor {
public:
void operator()(expr const &, expr const &) {}
};
/** f is invoked for each subexpression \c s of the input expression e.
\brief Functional for applying <tt>F</tt> to the subexpressions of a given expression. In a call <tt>f(s, n)</tt>, n is the scope level, i.e., the number of
The signature of \c F is
expr const &, unsigned -> optional(expr)
F is invoked for each subexpression \c s of the input expression e.
In a call <tt>F(s, n)</tt>, n is the scope level, i.e., the number of
bindings operators that enclosing \c s. The replaces only visits children of \c e bindings operators that enclosing \c s. The replaces only visits children of \c e
if F return none_expr if f return none_expr.
P is a "post-processing" functional object that is applied to each
pair (old, new)
*/ */
class replace_fn { expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned)> const & f);
struct frame {
expr m_expr;
unsigned m_offset;
bool m_shared;
unsigned m_index;
frame(expr const & e, unsigned o, bool s):m_expr(e), m_offset(o), m_shared(s), m_index(0) {}
};
typedef buffer<frame> frame_stack;
typedef buffer<expr> result_stack;
expr_cell_offset_map<expr> m_cache;
std::function<optional<expr>(expr const &, unsigned)> m_f;
std::function<void(expr const &, expr const &)> m_post;
frame_stack m_fs;
result_stack m_rs;
void save_result(expr const & e, expr const & r, unsigned offset, bool shared);
bool visit(expr const & e, unsigned offset);
bool check_index(frame & f, unsigned idx);
expr const & rs(int i);
void pop_rs(unsigned num);
public:
template<typename F, typename P = default_replace_postprocessor>
replace_fn(F const & f, P const & p = P()):
m_f(f), m_post(p) {}
expr operator()(expr const & e);
void clear();
};
template<typename F> expr replace(expr const & e, F const & f) {
return replace_fn(f)(e);
}
template<typename F, typename P> expr replace(expr const & e, F const & f, P const & p) {
return replace_fn(f, p)(e);
}
} }

View file

@ -56,63 +56,10 @@ public:
} }
}; };
static expr arg(expr n, unsigned i) {
buffer<expr> args;
while (is_app(n)) {
args.push_back(app_arg(n));
n = app_fn(n);
}
args.push_back(n);
return args[args.size() - i - 1];
}
static void tst3() {
expr f = Const("f");
expr c = Const("c");
expr d = Const("d");
expr A = Const("A");
expr_map<expr> trace;
auto proc = [&](expr const & x, unsigned offset) -> optional<expr> {
if (is_var(x)) {
unsigned vidx = var_idx(x);
if (vidx == offset)
return some_expr(c);
else if (vidx > offset)
return some_expr(mk_var(vidx-1));
else
return none_expr();
} else {
return none_expr();
}
};
expr x = Local("x", A);
expr y = Local("y", A);
replace_fn replacer(proc, tracer(trace));
expr t = Fun({x, y}, f(x, f(f(f(x, x), f(y, d)), f(d, d))));
expr b = binding_body(t);
expr r = replacer(b);
std::cout << r << "\n";
lean_assert(r == Fun(y, f(c, f(f(f(c, c), f(y, d)), f(d, d)))));
for (auto p : trace) {
std::cout << p.first << " --> " << p.second << "\n";
}
lean_assert(trace[c] == Var(1));
std::cout << arg(arg(binding_body(r), 2), 2) << "\n";
lean_assert(arg(arg(binding_body(r), 2), 2) == f(d, d));
lean_assert(trace.find(arg(arg(binding_body(r), 2), 2)) == trace.end());
lean_assert(trace.find(binding_body(r)) != trace.end());
lean_assert(trace.find(arg(binding_body(r), 2)) != trace.end());
lean_assert(trace.find(arg(arg(binding_body(r), 2), 1)) != trace.end());
lean_assert(trace.find(arg(arg(arg(binding_body(r), 2), 1), 1)) != trace.end());
lean_assert(trace.find(arg(arg(arg(binding_body(r), 2), 1), 2)) == trace.end());
}
int main() { int main() {
save_stack_info(); save_stack_info();
tst1(); tst1();
tst2(); tst2();
tst3();
std::cout << "done" << "\n"; std::cout << "done" << "\n";
return has_violations() ? 1 : 0; return has_violations() ? 1 : 0;
} }