feat(kernel/expr): add optional expression caching (aka "partial" hash-consing)

We do not enforce full hash-consing because we would need to synchronize
the access to the hashtable/cache.

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-03 13:32:02 -07:00
parent 45a3ab5141
commit 4a25e7442a
10 changed files with 100 additions and 25 deletions

View file

@ -12,6 +12,7 @@ option(BOOST "BOOST" OFF)
option(STATIC "STATIC" OFF)
option(SPLIT_STACK "SPLIT_STACK" OFF)
option(READLINE "READLINE" OFF)
option(CACHE_EXPRS "CACHE_EXPRS" ON)
# Added for CTest
include(CTest)
@ -40,6 +41,11 @@ else()
set(LEAN_EXTRA_CXX_FLAGS "${LEAN_EXTRA_CXX_FLAGS} -D LEAN_MULTI_THREAD")
endif()
if("${CACHE_EXPRS}" MATCHES "ON")
message(STATUS "Lean expression caching enabled (aka partial hashconsing)")
set(LEAN_EXTRA_CXX_FLAGS "${LEAN_EXTRA_CXX_FLAGS} -D LEAN_CACHE_EXPRS")
endif()
if("${STATIC}" MATCHES "ON")
set(LEAN_EXTRA_LINKER_FLAGS "${LEAN_EXTRA_LINKER_FLAGS} -static")
message(STATUS "Creating a static executable")

View file

@ -14,10 +14,15 @@ Author: Leonardo de Moura
#include "util/hash.h"
#include "util/buffer.h"
#include "util/object_serializer.h"
#include "util/lru_cache.h"
#include "kernel/expr.h"
#include "kernel/expr_eq_fn.h"
#include "kernel/free_vars.h"
#ifndef LEAN_INITIAL_EXPR_CACHE_CAPACITY
#define LEAN_INITIAL_EXPR_CACHE_CAPACITY 1024*16
#endif
namespace lean {
static expr g_dummy(mk_var(0));
expr::expr():expr(g_dummy) {}
@ -156,6 +161,7 @@ expr_binding::expr_binding(expr_kind k, name const & n, expr const & t, expr con
std::max(get_free_var_range(t), dec(get_free_var_range(b)))),
m_binder(n, t, i),
m_body(b) {
m_hash = ::lean::hash(m_hash, m_depth);
lean_assert(k == expr_kind::Lambda || k == expr_kind::Pi);
}
void expr_binding::dealloc(buffer<expr_cell*> & todelete) {
@ -257,6 +263,45 @@ expr_macro::~expr_macro() {
delete[] m_args;
}
// =======================================
// Constructors
#ifdef LEAN_CACHE_EXPRS
typedef lru_cache<expr, expr_hash, is_bi_equal_proc> expr_cache;
static expr_cache LEAN_THREAD_LOCAL g_expr_cache(LEAN_INITIAL_EXPR_CACHE_CAPACITY);
static bool LEAN_THREAD_LOCAL g_expr_cache_enabled = true;
inline expr cache(expr const & e) {
if (g_expr_cache_enabled) {
if (auto r = g_expr_cache.insert(e)) {
// std::cout << e << "\n===>\n" << *r << "\n";
return *r;
}
}
return e;
}
bool enable_expr_caching(bool f) {
bool r = g_expr_cache_enabled;
g_expr_cache_enabled = f;
return r;
}
#else
inline expr cache(expr && e) { return e; }
bool enable_expr_caching(bool) { return true; } // NOLINT
#endif
expr mk_var(unsigned idx) { return cache(expr(new expr_var(idx))); }
expr mk_constant(name const & n, levels const & ls) { return cache(expr(new expr_const(n, ls))); }
expr mk_macro(macro_definition const & m, unsigned num, expr const * args) { return cache(expr(new expr_macro(m, num, args))); }
expr mk_metavar(name const & n, expr const & t) { return cache(expr(new expr_mlocal(true, n, t))); }
expr mk_local(name const & n, name const & pp_n, expr const & t) { return cache(expr(new expr_local(n, pp_n, t))); }
expr mk_app(expr const & f, expr const & a) { return cache(expr(new expr_app(f, a))); }
expr mk_binding(expr_kind k, name const & n, expr const & t, expr const & e, binder_info const & i) {
return cache(expr(new expr_binding(k, n, t, e, i)));
}
expr mk_let(name const & n, expr const & t, expr const & v, expr const & e) { return cache(expr(new expr_let(n, t, v, e))); }
expr mk_sort(level const & l) { return cache(expr(new expr_sort(l))); }
// =======================================
void expr_cell::dealloc() {
try {
buffer<expr_cell*> todo;

View file

@ -156,6 +156,7 @@ bool operator==(expr const & a, expr const & b);
inline bool operator!=(expr const & a, expr const & b) { return !operator==(a, b); }
/** \brief Similar to ==, but it also compares binder information */
bool is_bi_equal(expr const & a, expr const & b);
struct is_bi_equal_proc { bool operator()(expr const & e1, expr const & e2) const { return is_bi_equal(e1, e2); } };
// =======================================
SPECIALIZE_OPTIONAL_FOR_SMART_PTR(expr)
@ -422,35 +423,33 @@ bool is_meta(expr const & e);
// =======================================
// Constructors
inline expr mk_var(unsigned idx) { return expr(new expr_var(idx)); }
expr mk_var(unsigned idx);
inline expr Var(unsigned idx) { return mk_var(idx); }
inline expr mk_constant(name const & n, levels const & ls) { return expr(new expr_const(n, ls)); }
expr mk_constant(name const & n, levels const & ls);
inline expr mk_constant(name const & n) { return mk_constant(n, levels()); }
inline expr Const(name const & n) { return mk_constant(n); }
inline expr mk_macro(macro_definition const & m, unsigned num = 0, expr const * args = nullptr) { return expr(new expr_macro(m, num, args)); }
inline expr mk_metavar(name const & n, expr const & t) { return expr(new expr_mlocal(true, n, t)); }
inline expr mk_local(name const & n, name const & pp_n, expr const & t) { return expr(new expr_local(n, pp_n, t)); }
inline expr mk_app(expr const & f, expr const & a) { return expr(new expr_app(f, a)); }
expr mk_app(expr const & f, unsigned num_args, expr const * args);
expr mk_app(unsigned num_args, expr const * args);
expr mk_macro(macro_definition const & m, unsigned num = 0, expr const * args = nullptr);
expr mk_metavar(name const & n, expr const & t);
expr mk_local(name const & n, name const & pp_n, expr const & t);
expr mk_app(expr const & f, expr const & a);
expr mk_app(expr const & f, unsigned num_args, expr const * args);
expr mk_app(unsigned num_args, expr const * args);
inline expr mk_app(std::initializer_list<expr> const & l) { return mk_app(l.size(), l.begin()); }
template<typename T> expr mk_app(T const & args) { return mk_app(args.size(), args.data()); }
template<typename T> expr mk_app(expr const & f, T const & args) { return mk_app(f, args.size(), args.data()); }
expr mk_rev_app(expr const & f, unsigned num_args, expr const * args);
expr mk_rev_app(unsigned num_args, expr const * args);
expr mk_rev_app(expr const & f, unsigned num_args, expr const * args);
expr mk_rev_app(unsigned num_args, expr const * args);
template<typename T> expr mk_rev_app(T const & args) { return mk_rev_app(args.size(), args.data()); }
template<typename T> expr mk_rev_app(expr const & f, T const & args) { return mk_rev_app(f, args.size(), args.data()); }
inline expr mk_binding(expr_kind k, name const & n, expr const & t, expr const & e, binder_info const & i = binder_info()) {
return expr(new expr_binding(k, n, t, e, i));
}
expr mk_binding(expr_kind k, name const & n, expr const & t, expr const & e, binder_info const & i = binder_info());
inline expr mk_lambda(name const & n, expr const & t, expr const & e, binder_info const & i = binder_info()) {
return mk_binding(expr_kind::Lambda, n, t, e, i);
}
inline expr mk_pi(name const & n, expr const & t, expr const & e, binder_info const & i = binder_info()) {
return mk_binding(expr_kind::Pi, n, t, e, i);
}
inline expr mk_let(name const & n, expr const & t, expr const & v, expr const & e) { return expr(new expr_let(n, t, v, e)); }
inline expr mk_sort(level const & l) { return expr(new expr_sort(l)); }
expr mk_let(name const & n, expr const & t, expr const & v, expr const & e);
expr mk_sort(level const & l);
/** \brief Return <tt>Pi(x.{sz-1}, domain[sz-1], ..., Pi(x.{0}, domain[0], range)...)</tt> */
expr mk_pi(unsigned sz, expr const * domain, expr const & range);
@ -495,6 +494,14 @@ inline expr expr::operator()(expr const & a1, expr const & a2, expr const & a3,
/** \brief Return application (...((f x_{n-1}) x_{n-2}) ... x_0) */
expr mk_app_vars(expr const & f, unsigned n);
bool enable_expr_caching(bool f);
/** \brief Helper class for temporarily enabling/disabling expression caching */
struct scoped_expr_caching {
bool m_old;
scoped_expr_caching(bool f) { m_old = enable_expr_caching(f); }
~scoped_expr_caching() { enable_expr_caching(m_old); }
};
// =======================================
// =======================================

View file

@ -27,7 +27,6 @@ using expr_cell_offset_map = typename std::unordered_map<expr_cell_offset, T, ex
template<typename T>
using expr_struct_map = typename std::unordered_map<expr, T, expr_hash, std::equal_to<expr>>;
// The following map also takes into account binder information
struct is_bi_equal_proc { bool operator()(expr const & e1, expr const & e2) const { return is_bi_equal(e1, e2); } };
template<typename T>
using expr_bi_struct_map = typename std::unordered_map<expr, T, expr_hash, is_bi_equal_proc>;
};

View file

@ -201,6 +201,7 @@ struct print_expr_fn {
print_expr_fn(std::ostream & out, bool type0_as_bool = true):m_out(out), m_type0_as_bool(type0_as_bool) {}
void operator()(expr const & e) {
scoped_expr_caching set(false);
print(e);
}
};

View file

@ -59,7 +59,10 @@ public:
\brief Return a new expression that is equal to the given
argument, but does not share any memory cell with it.
*/
expr operator()(expr const & a) { return apply(a); }
expr operator()(expr const & a) {
scoped_expr_caching set(false);
return apply(a);
}
};
expr deep_copy(expr const & e) { return deep_copy_fn()(e); }
}

View file

@ -769,6 +769,8 @@ static const struct luaL_Reg expr_m[] = {
{0, 0}
};
static int enable_expr_caching(lua_State * L) { return push_boolean(L, enable_expr_caching(lua_toboolean(L, 1))); }
static void expr_migrate(lua_State * src, int i, lua_State * tgt) {
push_expr(tgt, to_expr(src, i));
}
@ -800,6 +802,8 @@ static void open_expr(lua_State * L) {
SET_GLOBAL_FUN(expr_mk_local, "Local");
SET_GLOBAL_FUN(expr_pred, "is_expr");
SET_GLOBAL_FUN(enable_expr_caching, "enable_expr_caching");
push_expr(L, Bool);
lua_setglobal(L, "Bool");

View file

@ -44,7 +44,10 @@ static void tst1() {
std::cout << fa(a) << "\n";
lean_assert(is_eqp(app_fn(fa), f));
lean_assert(is_eqp(app_arg(fa), a));
lean_assert(!is_eqp(fa, f(a)));
{
scoped_expr_caching set(false);
lean_assert(!is_eqp(fa, f(a)));
}
lean_assert(fa(a) == f(a, a));
std::cout << fa(fa, fa) << "\n";
std::cout << mk_lambda("x", ty, Var(0)) << "\n";
@ -159,6 +162,7 @@ static void tst6() {
}
static void tst7() {
scoped_expr_caching set(false);
expr f = Const("f");
expr v = Var(0);
expr a1 = max_sharing(f(v, v));
@ -238,17 +242,20 @@ static void tst11() {
std::cout << abstract(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), Const("a")) << "\n";
lean_assert(abstract(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), Const("a")) ==
mk_lambda("x", t, f(Var(1), mk_lambda("y", t, f(b, Var(2))))));
std::cout << abstract_p(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), Const("a")) << "\n";
lean_assert(abstract_p(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), Const("a")) ==
mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))));
std::cout << abstract_p(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), a) << "\n";
lean_assert(abstract_p(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), a) ==
mk_lambda("x", t, f(Var(1), mk_lambda("y", t, f(b, Var(2))))));
{
scoped_expr_caching set(false);
std::cout << abstract_p(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), Const("a")) << "\n";
lean_assert(abstract_p(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), Const("a")) ==
mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))));
std::cout << abstract_p(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), a) << "\n";
lean_assert(abstract_p(mk_lambda("x", t, f(a, mk_lambda("y", t, f(b, a)))), a) ==
mk_lambda("x", t, f(Var(1), mk_lambda("y", t, f(b, Var(2))))));
}
lean_assert(substitute(f(f(f(a))), f(a), b) == f(f(b)));
}
static void tst12() {
scoped_expr_caching set(false);
expr f = Const("f");
expr v = Var(0);
expr a1 = max_sharing(f(v, v));
@ -305,6 +312,7 @@ static void tst15() {
}
static void check_copy(expr const & e) {
scoped_expr_caching set(false);
expr c = copy(e);
lean_assert(!is_eqp(e, c));
lean_assert(e == c);

View file

@ -64,6 +64,7 @@ static void tst3() {
int main() {
save_stack_info();
scoped_expr_caching set(false);
tst1();
tst2();
tst3();

View file

@ -93,6 +93,7 @@ assert(f(a, b):abstract(a) == f(Var(0), b))
assert(f(a, b):abstract({a, b}) == f(Var(1), Var(0)))
assert(a:occurs(f(a)))
enable_expr_caching(false)
assert(not f(a):is_eqp(f(a)))
assert(f(a):arg():is_eqp(a))
assert(f(a):depth() == 2)