refactor(library/blast): simplify blast/expr

This commit is contained in:
Leonardo de Moura 2015-09-30 12:54:03 -07:00
parent c5603e456a
commit ac7c0fffd8
2 changed files with 17 additions and 344 deletions

View file

@ -13,34 +13,26 @@ Author: Leonardo de Moura
#include "kernel/instantiate_univ_cache.h" #include "kernel/instantiate_univ_cache.h"
#include "library/blast/expr.h" #include "library/blast/expr.h"
#ifndef LEAN_DEFAULT_BLAST_REPLACE_CACHE_CAPACITY
#define LEAN_DEFAULT_BLAST_REPLACE_CACHE_CAPACITY 1024*8
#endif
#ifndef LEAN_BLAST_INST_UNIV_CACHE_SIZE #ifndef LEAN_BLAST_INST_UNIV_CACHE_SIZE
#define LEAN_BLAST_INST_UNIV_CACHE_SIZE 1023 #define LEAN_BLAST_INST_UNIV_CACHE_SIZE 1023
#endif #endif
namespace lean { namespace lean {
namespace blast { namespace blast {
typedef typename std::unordered_set<expr, expr_hash, is_bi_equal_proc> expr_table;
typedef typename std::unordered_set<level, level_hash> level_table; typedef typename std::unordered_set<level, level_hash> level_table;
typedef typename std::vector<expr> expr_array; typedef typename std::vector<expr> expr_array;
LEAN_THREAD_PTR(level_table, g_level_table); LEAN_THREAD_PTR(level_table, g_level_table);
LEAN_THREAD_PTR(expr_table, g_expr_table);
LEAN_THREAD_PTR(expr_array, g_var_array); LEAN_THREAD_PTR(expr_array, g_var_array);
LEAN_THREAD_PTR(expr_array, g_mref_array); LEAN_THREAD_PTR(expr_array, g_mref_array);
LEAN_THREAD_PTR(expr_array, g_href_array); LEAN_THREAD_PTR(expr_array, g_href_array);
scope_hash_consing::scope_hash_consing() { scope_hash_consing::scope_hash_consing() {
m_level_table = g_level_table; lean_assert(g_level_table == nullptr);
m_expr_table = g_expr_table; lean_assert(g_var_array == nullptr);
m_var_array = g_var_array; lean_assert(g_mref_array == nullptr);
m_mref_array = g_mref_array; lean_assert(g_href_array == nullptr);
m_href_array = g_href_array;
g_level_table = new level_table(); g_level_table = new level_table();
g_expr_table = new expr_table();
g_var_array = new expr_array(); g_var_array = new expr_array();
g_mref_array = new expr_array(); g_mref_array = new expr_array();
g_href_array = new expr_array(); g_href_array = new expr_array();
@ -50,15 +42,13 @@ scope_hash_consing::scope_hash_consing() {
scope_hash_consing::~scope_hash_consing() { scope_hash_consing::~scope_hash_consing() {
delete g_level_table; delete g_level_table;
delete g_expr_table;
delete g_var_array; delete g_var_array;
delete g_mref_array; delete g_mref_array;
delete g_href_array; delete g_href_array;
g_level_table = reinterpret_cast<level_table*>(m_level_table); g_level_table = nullptr;
g_expr_table = reinterpret_cast<expr_table*>(m_expr_table); g_var_array = nullptr;
g_var_array = reinterpret_cast<expr_array*>(m_var_array); g_mref_array = nullptr;
g_mref_array = reinterpret_cast<expr_array*>(m_mref_array); g_href_array = nullptr;
g_href_array = reinterpret_cast<expr_array*>(m_href_array);
} }
#ifdef LEAN_DEBUG #ifdef LEAN_DEBUG
@ -66,11 +56,6 @@ static bool is_cached(level const & l) {
lean_assert(g_level_table); lean_assert(g_level_table);
return g_level_table->find(l) != g_level_table->end(); return g_level_table->find(l) != g_level_table->end();
} }
static bool is_cached(expr const & l) {
lean_assert(g_expr_table);
return g_expr_table->find(l) != g_expr_table->end();
}
#endif #endif
static level cache(level const & l) { static level cache(level const & l) {
@ -82,15 +67,6 @@ static level cache(level const & l) {
return l; return l;
} }
static expr cache(expr const & e) {
lean_assert(g_expr_table);
auto r = g_expr_table->find(e);
if (r != g_expr_table->end())
return *r;
g_expr_table->insert(e);
return e;
}
level mk_level_zero() { level mk_level_zero() {
return lean::mk_level_zero(); return lean::mk_level_zero();
} }
@ -158,12 +134,10 @@ static expr mk_href_core(unsigned idx) {
expr mk_href(unsigned idx) { expr mk_href(unsigned idx) {
lean_assert(g_href_array); lean_assert(g_href_array);
lean_assert(g_expr_table);
while (g_href_array->size() <= idx) { while (g_href_array->size() <= idx) {
unsigned j = g_href_array->size(); unsigned j = g_href_array->size();
expr new_ref = mk_href_core(j); expr new_ref = mk_href_core(j);
g_href_array->push_back(new_ref); g_href_array->push_back(new_ref);
g_expr_table->insert(new_ref);
} }
lean_assert(idx < g_href_array->size()); lean_assert(idx < g_href_array->size());
return (*g_href_array)[idx]; return (*g_href_array)[idx];
@ -179,12 +153,10 @@ static expr mk_mref_core(unsigned idx) {
expr mk_mref(unsigned idx) { expr mk_mref(unsigned idx) {
lean_assert(g_mref_array); lean_assert(g_mref_array);
lean_assert(g_expr_table);
while (g_mref_array->size() <= idx) { while (g_mref_array->size() <= idx) {
unsigned j = g_mref_array->size(); unsigned j = g_mref_array->size();
expr new_ref = mk_mref_core(j); expr new_ref = mk_mref_core(j);
g_mref_array->push_back(new_ref); g_mref_array->push_back(new_ref);
g_expr_table->insert(new_ref);
} }
lean_assert(idx < g_mref_array->size()); lean_assert(idx < g_mref_array->size());
return (*g_mref_array)[idx]; return (*g_mref_array)[idx];
@ -214,7 +186,7 @@ bool has_mref(expr const & e) {
expr mk_local(unsigned idx, expr const & t) { expr mk_local(unsigned idx, expr const & t) {
lean_assert(is_cached(t)); lean_assert(is_cached(t));
return cache(lean::mk_local(name(*g_prefix, idx), t)); return lean::mk_local(name(*g_prefix, idx), t);
} }
bool is_local(expr const & e) { bool is_local(expr const & e) {
@ -237,12 +209,10 @@ expr const & local_type(expr const & e) {
expr mk_var(unsigned idx) { expr mk_var(unsigned idx) {
lean_assert(g_var_array); lean_assert(g_var_array);
lean_assert(g_expr_table);
while (g_var_array->size() <= idx) { while (g_var_array->size() <= idx) {
unsigned j = g_var_array->size(); unsigned j = g_var_array->size();
expr new_var = lean::mk_var(j); expr new_var = lean::mk_var(j);
g_var_array->push_back(new_var); g_var_array->push_back(new_var);
g_expr_table->insert(new_var);
} }
lean_assert(idx < g_var_array->size()); lean_assert(idx < g_var_array->size());
return (*g_var_array)[idx]; return (*g_var_array)[idx];
@ -251,7 +221,7 @@ expr mk_var(unsigned idx) {
expr mk_app(expr const & f, expr const & a) { expr mk_app(expr const & f, expr const & a) {
lean_assert(is_cached(f)); lean_assert(is_cached(f));
lean_assert(is_cached(a)); lean_assert(is_cached(a));
return cache(lean::mk_app(f, a)); return lean::mk_app(f, a);
} }
expr mk_app(expr const & f, unsigned num_args, expr const * args) { expr mk_app(expr const & f, unsigned num_args, expr const * args) {
@ -268,23 +238,23 @@ expr mk_app(unsigned num_args, expr const * args) {
expr mk_sort(level const & l) { expr mk_sort(level const & l) {
lean_assert(is_cached(l)); lean_assert(is_cached(l));
return cache(lean::mk_sort(l)); return lean::mk_sort(l);
} }
expr mk_constant(name const & n, levels const & ls) { expr mk_constant(name const & n, levels const & ls) {
lean_assert(std::all_of(ls.begin(), ls.end(), [](level const & l) { return is_cached(l); })); lean_assert(std::all_of(ls.begin(), ls.end(), [](level const & l) { return is_cached(l); }));
return cache(lean::mk_constant(n, ls)); return lean::mk_constant(n, ls);
} }
expr mk_binding(expr_kind k, name const & n, expr const & t, expr const & e, binder_info const & bi) { expr mk_binding(expr_kind k, name const & n, expr const & t, expr const & e, binder_info const & bi) {
lean_assert(is_cached(t)); lean_assert(is_cached(t));
lean_assert(is_cached(e)); lean_assert(is_cached(e));
return cache(lean::mk_binding(k, n, t, e, bi)); return lean::mk_binding(k, n, t, e, bi);
} }
expr mk_macro(macro_definition const & m, unsigned num, expr const * args) { expr mk_macro(macro_definition const & m, unsigned num, expr const * args) {
lean_assert(std::all_of(args, args+num, [](expr const & e) { return is_cached(e); })); lean_assert(std::all_of(args, args+num, [](expr const & e) { return is_cached(e); }));
return cache(lean::mk_macro(m, num, args)); return lean::mk_macro(m, num, args);
} }
expr update_app(expr const & e, expr const & new_fn, expr const & new_arg) { expr update_app(expr const & e, expr const & new_fn, expr const & new_arg) {
@ -328,299 +298,6 @@ expr update_macro(expr const & e, unsigned num, expr const * args) {
return blast::mk_macro(to_macro(e)->get_def(), num, args); return blast::mk_macro(to_macro(e)->get_def(), num, args);
} }
MK_CACHE_STACK(replace_cache, LEAN_DEFAULT_BLAST_REPLACE_CACHE_CAPACITY)
class replace_rec_fn {
replace_cache_ref m_cache;
std::function<optional<expr>(expr const &, unsigned)> m_f;
expr save_result(expr const & e, unsigned offset, expr const & r) {
m_cache->insert(e, offset, r);
return r;
}
expr apply(expr const & e, unsigned offset) {
if (auto r = m_cache->find(e, offset))
return *r;
check_interrupted();
check_memory("replace");
if (optional<expr> r = m_f(e, offset)) {
return save_result(e, offset, *r);
} else {
switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Var:
return save_result(e, offset, e);
case expr_kind::Meta:
lean_assert(is_mref(e));
return save_result(e, offset, e);
case expr_kind::Local:
lean_assert(is_href(e));
return save_result(e, offset, e);
case expr_kind::App: {
expr new_f = apply(app_fn(e), offset);
expr new_a = apply(app_arg(e), offset);
return save_result(e, offset, blast::update_app(e, new_f, new_a));
}
case expr_kind::Pi: case expr_kind::Lambda: {
expr new_d = apply(binding_domain(e), offset);
expr new_b = apply(binding_body(e), offset+1);
return save_result(e, offset, blast::update_binding(e, new_d, new_b));
}
case expr_kind::Macro:
if (macro_num_args(e) == 0) {
return save_result(e, offset, e);
} else {
buffer<expr> new_args;
unsigned nargs = macro_num_args(e);
for (unsigned i = 0; i < nargs; i++)
new_args.push_back(apply(macro_arg(e, i), offset));
return save_result(e, offset, blast::update_macro(e, new_args.size(), new_args.data()));
}
}
lean_unreachable();
}
}
public:
template<typename F>
replace_rec_fn(F const & f):m_f(f) {}
expr operator()(expr const & e) { return apply(e, 0); }
};
expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned)> const & f) {
return replace_rec_fn(f)(e);
}
expr lift_free_vars(expr const & e, unsigned s, unsigned d) {
if (d == 0 || s >= get_free_var_range(e))
return e;
return blast::replace(e, [=](expr const & e, unsigned offset) -> optional<expr> {
unsigned s1 = s + offset;
if (s1 < s)
return some_expr(e); // overflow, vidx can't be >= max unsigned
if (s1 >= get_free_var_range(e))
return some_expr(e); // expression e does not contain free variables with idx >= s1
if (is_var(e) && var_idx(e) >= s + offset) {
unsigned new_idx = var_idx(e) + d;
if (new_idx < var_idx(e))
throw exception("invalid lift_free_vars operation, index overflow");
return some_expr(blast::mk_var(new_idx));
} else {
return none_expr();
}
});
}
expr lift_free_vars(expr const & e, unsigned d) {
return blast::lift_free_vars(e, 0, d);
}
template<bool rev>
struct instantiate_easy_fn {
unsigned n;
expr const * subst;
instantiate_easy_fn(unsigned _n, expr const * _subst):n(_n), subst(_subst) {}
optional<expr> operator()(expr const & a, bool app) const {
if (closed(a))
return some_expr(a);
if (is_var(a) && var_idx(a) < n)
return some_expr(subst[rev ? n - var_idx(a) - 1 : var_idx(a)]);
if (app && is_app(a))
if (auto new_a = operator()(app_arg(a), false))
if (auto new_f = operator()(app_fn(a), true))
return some_expr(blast::mk_app(*new_f, *new_a));
return none_expr();
}
};
expr instantiate(expr const & a, unsigned s, unsigned n, expr const * subst) {
if (s >= get_free_var_range(a) || n == 0)
return a;
if (s == 0)
if (auto r = blast::instantiate_easy_fn<false>(n, subst)(a, true))
return *r;
return blast::replace(a, [=](expr const & m, unsigned offset) -> optional<expr> {
unsigned s1 = s + offset;
if (s1 < s)
return some_expr(m); // overflow, vidx can't be >= max unsigned
if (s1 >= get_free_var_range(m))
return some_expr(m); // expression m does not contain free variables with idx >= s1
if (is_var(m)) {
unsigned vidx = var_idx(m);
if (vidx >= s1) {
unsigned h = s1 + n;
if (h < s1 /* overflow, h is bigger than any vidx */ || vidx < h) {
return some_expr(blast::lift_free_vars(subst[vidx - s1], offset));
} else {
return some_expr(blast::mk_var(vidx - n));
}
}
}
return none_expr();
});
}
expr instantiate(expr const & e, unsigned n, expr const * s) {
return blast::instantiate(e, 0, n, s);
}
expr instantiate_rev(expr const & a, unsigned n, expr const * subst) {
if (closed(a))
return a;
if (auto r = blast::instantiate_easy_fn<true>(n, subst)(a, true))
return *r;
return blast::replace(a, [=](expr const & m, unsigned offset) -> optional<expr> {
if (offset >= get_free_var_range(m))
return some_expr(m); // expression m does not contain free variables with idx >= offset
if (is_var(m)) {
unsigned vidx = var_idx(m);
if (vidx >= offset) {
unsigned h = offset + n;
if (h < offset /* overflow, h is bigger than any vidx */ || vidx < h) {
return some_expr(blast::lift_free_vars(subst[n - (vidx - offset) - 1], offset));
} else {
return some_expr(blast::mk_var(vidx - n));
}
}
}
return none_expr();
});
}
class replace_level_fn {
std::function<optional<level>(level const &)> m_f;
level apply(level const & l) {
optional<level> r = m_f(l);
if (r)
return *r;
switch (l.kind()) {
case level_kind::Succ:
return blast::update_succ(l, apply(succ_of(l)));
case level_kind::Max:
return blast::update_max(l, apply(max_lhs(l)), apply(max_rhs(l)));
case level_kind::IMax:
return blast::update_max(l, apply(imax_lhs(l)), apply(imax_rhs(l)));
case level_kind::Zero: case level_kind::Param: case level_kind::Meta: case level_kind::Global:
return l;
}
lean_unreachable(); // LCOV_EXCL_LINE
}
public:
template<typename F> replace_level_fn(F const & f):m_f(f) {}
level operator()(level const & l) { return apply(l); }
};
level replace(level const & l, std::function<optional<level>(level const & l)> const & f) {
return replace_level_fn(f)(l);
}
level instantiate(level const & l, level_param_names const & ps, levels const & ls) {
lean_assert(length(ps) == length(ls));
return blast::replace(l, [=](level const & l) {
if (!has_param(l)) {
return some_level(l);
} else if (is_param(l)) {
name const & id = param_id(l);
list<name> const *it1 = &ps;
list<level> const * it2 = &ls;
while (!is_nil(*it1)) {
if (head(*it1) == id)
return some_level(head(*it2));
it1 = &tail(*it1);
it2 = &tail(*it2);
}
return some_level(l);
} else {
return none_level();
}
});
}
expr instantiate_univ_params(expr const & e, level_param_names const & ps, levels const & ls) {
if (!has_param_univ(e))
return e;
return blast::replace(e, [&](expr const & e) -> optional<expr> {
if (!has_param_univ(e))
return some_expr(e);
if (is_constant(e)) {
levels new_ls = map_reuse(const_levels(e),
[&](level const & l) { return blast::instantiate(l, ps, ls); },
[](level const & l1, level const & l2) { return is_eqp(l1, l2); });
return some_expr(blast::update_constant(e, new_ls));
} else if (is_sort(e)) {
return some_expr(blast::update_sort(e, blast::instantiate(sort_level(e), ps, ls)));
} else {
return none_expr();
}
});
}
MK_THREAD_LOCAL_GET(instantiate_univ_cache, get_type_univ_cache, LEAN_BLAST_INST_UNIV_CACHE_SIZE);
MK_THREAD_LOCAL_GET(instantiate_univ_cache, get_value_univ_cache, LEAN_BLAST_INST_UNIV_CACHE_SIZE);
expr instantiate_type_univ_params(declaration const & d, levels const & ls) {
lean_assert(d.get_num_univ_params() == length(ls));
if (is_nil(ls) || !has_param_univ(d.get_type()))
return d.get_type();
instantiate_univ_cache & cache = get_type_univ_cache();
if (auto r = cache.is_cached(d, ls))
return *r;
expr r = blast::instantiate_univ_params(d.get_type(), d.get_univ_params(), ls);
cache.save(d, ls, r);
return r;
}
expr instantiate_value_univ_params(declaration const & d, levels const & ls) {
lean_assert(d.get_num_univ_params() == length(ls));
if (is_nil(ls) || !has_param_univ(d.get_value()))
return d.get_value();
instantiate_univ_cache & cache = get_value_univ_cache();
if (auto r = cache.is_cached(d, ls))
return *r;
expr r = blast::instantiate_univ_params(d.get_value(), d.get_univ_params(), ls);
cache.save(d, ls, r);
return r;
}
expr abstract_hrefs(expr const & e, unsigned n, expr const * subst) {
if (!has_href(e))
return e;
lean_assert(std::all_of(subst, subst+n, [](expr const & e) { return closed(e) && is_href(e); }));
return blast::replace(e, [=](expr const & m, unsigned offset) -> optional<expr> {
if (!has_href(m))
return some_expr(m); // skip: m does not contain href's
if (is_href(m)) {
unsigned i = n;
while (i > 0) {
--i;
if (href_index(subst[i]) == href_index(m))
return some_expr(blast::mk_var(offset + n - i - 1));
}
return none_expr();
}
return none_expr();
});
}
expr abstract_locals(expr const & e, unsigned n, expr const * subst) {
if (!blast::has_local(e))
return e;
lean_assert(std::all_of(subst, subst+n, [](expr const & e) { return closed(e) && blast::is_local(e); }));
return blast::replace(e, [=](expr const & m, unsigned offset) -> optional<expr> {
if (!blast::has_local(m))
return some_expr(m); // skip: m does not contain locals
if (blast::is_local(m)) {
unsigned i = n;
while (i > 0) {
--i;
if (local_index(subst[i]) == local_index(m)) //
return some_expr(blast::mk_var(offset + n - i - 1));
}
return none_expr();
}
return none_expr();
});
}
void initialize_expr() { void initialize_expr() {
g_prefix = new name(name::mk_internal_unique_name()); g_prefix = new name(name::mk_internal_unique_name());
g_dummy_type = new expr(mk_constant(*g_prefix)); g_dummy_type = new expr(mk_constant(*g_prefix));

View file

@ -12,19 +12,15 @@ namespace lean {
namespace blast { namespace blast {
// API for creating maximally shared terms used by the blast tactic. // API for creating maximally shared terms used by the blast tactic.
// The API assumes there is a single blast tactic using theses terms. // The API assumes there is a single blast tactic using theses terms.
// The hash-consing tables are thread local. // The expression hash-consing tables are thread local and implemented
// in the kernel
// Remark: All procedures assume the children levels and expressions are maximally shared. // Remark: All procedures assume the children levels and expressions are maximally shared.
// That is, it assumes they have been created using the APIs provided by this module. // That is, it assumes they have been created using the APIs provided by this module.
// Auxiliary object for resetting the the thread local hash-consing tables. // Auxiliary object for resetting the the thread local hash-consing tables.
// Its destructor restores the state of the hash-consing tables. // It also uses an assertion to make sure it is not being used in a recursion.
class scope_hash_consing { class scope_hash_consing {
void * m_level_table;
void * m_expr_table;
void * m_var_array;
void * m_mref_array;
void * m_href_array;
public: public:
scope_hash_consing(); scope_hash_consing();
~scope_hash_consing(); ~scope_hash_consing();