Add max_shared flag to expr_cell. Improve app constructor.
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
2a9d0de57b
commit
5aa25a635f
5 changed files with 122 additions and 76 deletions
|
@ -20,7 +20,8 @@ unsigned hash_vars(unsigned size, uvar const * vars) {
|
|||
}
|
||||
|
||||
expr_cell::expr_cell(expr_kind k, unsigned h):
|
||||
m_kind(k),
|
||||
m_kind(static_cast<unsigned>(k)),
|
||||
m_max_shared(0),
|
||||
m_hash(h),
|
||||
m_rc(1) {}
|
||||
|
||||
|
@ -103,7 +104,7 @@ expr_numeral::expr_numeral(mpz const & n):
|
|||
m_numeral(n) {}
|
||||
|
||||
void expr_cell::dealloc() {
|
||||
switch (m_kind) {
|
||||
switch (kind()) {
|
||||
case expr_kind::Var: delete static_cast<expr_var*>(this); break;
|
||||
case expr_kind::Constant: delete static_cast<expr_const*>(this); break;
|
||||
case expr_kind::App: static_cast<expr_app*>(this)->~expr_app(); delete[] reinterpret_cast<char*>(this); break;
|
||||
|
|
|
@ -33,6 +33,8 @@ The main API is divided in the following sections
|
|||
======================================= */
|
||||
enum class expr_kind { Var, Constant, App, Lambda, Pi, Prop, Type, Numeral };
|
||||
|
||||
class max_sharing_functor;
|
||||
|
||||
/**
|
||||
\brief Base class used to represent expressions.
|
||||
|
||||
|
@ -42,13 +44,17 @@ enum class expr_kind { Var, Constant, App, Lambda, Pi, Prop, Type, Numeral };
|
|||
*/
|
||||
class expr_cell {
|
||||
protected:
|
||||
expr_kind m_kind;
|
||||
unsigned m_kind:16;
|
||||
unsigned m_max_shared:1; // flag indicating if the cell has maximally shared subexpressions
|
||||
unsigned m_hash;
|
||||
MK_LEAN_RC(); // Declare m_rc counter
|
||||
void dealloc();
|
||||
bool max_shared() const { return m_max_shared == 1; }
|
||||
void set_max_shared() { lean_assert(!max_shared()); m_max_shared = 1; }
|
||||
friend class max_sharing_functor;
|
||||
public:
|
||||
expr_cell(expr_kind k, unsigned h);
|
||||
expr_kind kind() const { return m_kind; }
|
||||
expr_kind kind() const { return static_cast<expr_kind>(m_kind); }
|
||||
unsigned hash() const { return m_hash; }
|
||||
};
|
||||
|
||||
|
@ -119,7 +125,7 @@ public:
|
|||
};
|
||||
|
||||
// =======================================
|
||||
// Expr Representation
|
||||
// Expr (internal) Representation
|
||||
// 1. Free variables
|
||||
class expr_var : public expr_cell {
|
||||
unsigned m_vidx; // de Bruijn index
|
||||
|
@ -225,17 +231,21 @@ inline expr var(unsigned idx) { return expr(new expr_var(idx)); }
|
|||
inline expr constant(name const & n) { return expr(new expr_const(n)); }
|
||||
inline expr constant(name const & n, unsigned pos) { return expr(new expr_const(n, pos)); }
|
||||
expr app(unsigned num_args, expr const * args);
|
||||
inline expr app(std::initializer_list<expr> const & l) { return app(l.size(), l.begin()); }
|
||||
inline expr app(expr const & e1, expr const & e2) { expr args[2] = {e1, e2}; return app(2, args); }
|
||||
inline expr app(expr const & e1, expr const & e2, expr const & e3) { expr args[3] = {e1, e2, e3}; return app(3, args); }
|
||||
inline expr app(expr const & e1, expr const & e2, expr const & e3, expr const & e4) { expr args[4] = {e1, e2, e3, e4}; return app(4, args); }
|
||||
inline expr app(expr const & e1, expr const & e2, expr const & e3, expr const & e4, expr const & e5) { expr args[5] = {e1, e2, e3, e4, e5}; return app(5, args); }
|
||||
inline expr lambda(name const & n, expr const & t, expr const & e) { return expr(new expr_lambda(n, t, e)); }
|
||||
inline expr pi(name const & n, expr const & t, expr const & e) { return expr(new expr_pi(n, t, e)); }
|
||||
inline expr prop() { return expr(new expr_prop()); }
|
||||
expr type(unsigned size, uvar const * vars);
|
||||
inline expr type(uvar const & uv) { return type(1, &uv); }
|
||||
inline expr type(std::initializer_list<uvar> const & l) { return type(l.size(), l.begin()); }
|
||||
inline expr numeral(mpz const & n) { return expr(new expr_numeral(n)); }
|
||||
// =======================================
|
||||
|
||||
// =======================================
|
||||
// Casting
|
||||
// Casting (these functions are only needed for low-level code)
|
||||
inline expr_var * to_var(expr_cell * e) { lean_assert(is_var(e)); return static_cast<expr_var*>(e); }
|
||||
inline expr_const * to_constant(expr_cell * e) { lean_assert(is_constant(e)); return static_cast<expr_const*>(e); }
|
||||
inline expr_app * to_app(expr_cell * e) { lean_assert(is_app(e)); return static_cast<expr_app*>(e); }
|
||||
|
|
|
@ -6,21 +6,35 @@ Author: Leonardo de Moura
|
|||
*/
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "expr.h"
|
||||
#include "expr_max_shared.h"
|
||||
#include "expr_functors.h"
|
||||
|
||||
namespace lean {
|
||||
namespace max_shared_ns {
|
||||
|
||||
class max_sharing_functor {
|
||||
struct expr_struct_eq { unsigned operator()(expr const & e1, expr const & e2) const { return e1 == e2; }};
|
||||
typedef typename std::unordered_set<expr, expr_hash, expr_struct_eq> expr_cache;
|
||||
static thread_local expr_cache g_cache;
|
||||
|
||||
expr_cache m_cache;
|
||||
|
||||
void cache(expr const & a) {
|
||||
a.raw()->set_max_shared();
|
||||
m_cache.insert(a);
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
expr apply(expr const & a) {
|
||||
auto r = g_cache.find(a);
|
||||
if (r != g_cache.end())
|
||||
if (a.raw()->max_shared())
|
||||
return a;
|
||||
auto r = m_cache.find(a);
|
||||
if (r != m_cache.end()) {
|
||||
lean_assert((*r).raw()->max_shared());
|
||||
return *r;
|
||||
}
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
|
||||
g_cache.insert(a);
|
||||
cache(a);
|
||||
return a;
|
||||
case expr_kind::App: {
|
||||
std::vector<expr> new_args;
|
||||
|
@ -31,12 +45,12 @@ expr apply(expr const & a) {
|
|||
modified = true;
|
||||
}
|
||||
if (!modified) {
|
||||
g_cache.insert(a);
|
||||
cache(a);
|
||||
return a;
|
||||
}
|
||||
else {
|
||||
expr r = app(static_cast<unsigned>(new_args.size()), new_args.data());
|
||||
g_cache.insert(r);
|
||||
expr r = app(new_args.size(), new_args.data());
|
||||
cache(r);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
@ -49,22 +63,27 @@ expr apply(expr const & a) {
|
|||
if (!eqp(t, old_t) || !eqp(e, old_e)) {
|
||||
name const & n = get_abs_name(a);
|
||||
expr r = is_pi(a) ? pi(n, t, e) : lambda(n, t, e);
|
||||
g_cache.insert(r);
|
||||
cache(r);
|
||||
return r;
|
||||
}
|
||||
else {
|
||||
g_cache.insert(a);
|
||||
cache(a);
|
||||
return a;
|
||||
}
|
||||
}}
|
||||
lean_unreachable();
|
||||
return a;
|
||||
}
|
||||
} // namespace max_shared_ns
|
||||
expr max_shared(expr const & a) {
|
||||
max_shared_ns::g_cache.clear();
|
||||
expr r = max_shared_ns::apply(a);
|
||||
max_shared_ns::g_cache.clear();
|
||||
return r;
|
||||
};
|
||||
|
||||
expr max_sharing(expr const & a) {
|
||||
if (a.raw()->max_shared()) {
|
||||
return a;
|
||||
}
|
||||
else {
|
||||
max_sharing_functor f;
|
||||
return f.apply(a);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace lean
|
||||
|
|
|
@ -9,7 +9,8 @@ Author: Leonardo de Moura
|
|||
|
||||
namespace lean {
|
||||
/**
|
||||
\brief Return a structurally identical expression that is maximally shared.
|
||||
\brief The resultant expression is structurally identical to the input one, but
|
||||
it uses maximally shared sub-expressions.
|
||||
*/
|
||||
expr max_shared(expr const & a);
|
||||
expr max_sharing(expr const & a);
|
||||
}
|
||||
|
|
|
@ -16,15 +16,17 @@ void tst1() {
|
|||
a = numeral(mpz(10));
|
||||
expr f;
|
||||
f = var(0);
|
||||
expr fa = app({f, a});
|
||||
expr fa = app(f, a);
|
||||
std::cout << fa << "\n";
|
||||
std::cout << app({fa, a}) << "\n";
|
||||
std::cout << app(fa, a) << "\n";
|
||||
lean_assert(eqp(get_arg(fa, 0), f));
|
||||
lean_assert(eqp(get_arg(fa, 1), a));
|
||||
lean_assert(!eqp(fa, app({f, a})));
|
||||
lean_assert(app({fa, a}) == app({f, a, a}));
|
||||
std::cout << app({fa, fa, fa}) << "\n";
|
||||
lean_assert(!eqp(fa, app(f, a)));
|
||||
lean_assert(app(fa, a) == app(f, a, a));
|
||||
std::cout << app(fa, fa, fa) << "\n";
|
||||
std::cout << lambda(name("x"), prop(), var(0)) << "\n";
|
||||
lean_assert(app(app(f, a), a) == app(f, a, a));
|
||||
lean_assert(app(f, app(a, a)) != app(f, a, a));
|
||||
}
|
||||
|
||||
expr mk_dag(unsigned depth) {
|
||||
|
@ -32,7 +34,7 @@ expr mk_dag(unsigned depth) {
|
|||
expr a = var(0);
|
||||
while (depth > 0) {
|
||||
depth--;
|
||||
a = app({f, a, a});
|
||||
a = app(f, a, a);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
@ -110,7 +112,7 @@ expr mk_big(expr f, unsigned depth, unsigned val) {
|
|||
if (depth == 1)
|
||||
return var(val);
|
||||
else
|
||||
return app({f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1)});
|
||||
return app(f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1));
|
||||
}
|
||||
|
||||
void tst3() {
|
||||
|
@ -124,7 +126,7 @@ void tst4() {
|
|||
expr f = constant(name("f"));
|
||||
expr a = var(0);
|
||||
for (unsigned i = 0; i < 10000; i++) {
|
||||
a = app({f, a});
|
||||
a = app(f, a);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -132,7 +134,7 @@ expr mk_redundant_dag(expr f, unsigned depth) {
|
|||
if (depth == 0)
|
||||
return var(0);
|
||||
else
|
||||
return app({f, mk_redundant_dag(f, depth - 1), mk_redundant_dag(f, depth - 1)});
|
||||
return app(f, mk_redundant_dag(f, depth - 1), mk_redundant_dag(f, depth - 1));
|
||||
}
|
||||
|
||||
|
||||
|
@ -161,20 +163,32 @@ void tst5() {
|
|||
expr f = constant(name("f"));
|
||||
{
|
||||
expr r1 = mk_redundant_dag(f, 5);
|
||||
expr r2 = max_shared(r1);
|
||||
expr r2 = max_sharing(r1);
|
||||
std::cout << "count(r1): " << count(r1) << "\n";
|
||||
std::cout << "count(r2): " << count(r2) << "\n";
|
||||
lean_assert(r1 == r2);
|
||||
}
|
||||
{
|
||||
expr r1 = mk_redundant_dag(f, 16);
|
||||
expr r2 = max_shared(r1);
|
||||
expr r2 = max_sharing(r1);
|
||||
lean_assert(r1 == r2);
|
||||
}
|
||||
}
|
||||
|
||||
void tst6() {
|
||||
expr f = constant(name("f"));
|
||||
expr r = mk_redundant_dag(f, 12);
|
||||
for (unsigned i = 0; i < 1000; i++) {
|
||||
r = max_sharing(r);
|
||||
}
|
||||
r = mk_big(f, 16, 0);
|
||||
for (unsigned i = 0; i < 1000000; i++) {
|
||||
r = max_sharing(r);
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
// continue_on_violation(true);
|
||||
continue_on_violation(true);
|
||||
std::cout << "sizeof(expr): " << sizeof(expr) << "\n";
|
||||
std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n";
|
||||
tst1();
|
||||
|
@ -182,6 +196,7 @@ int main() {
|
|||
tst3();
|
||||
tst4();
|
||||
tst5();
|
||||
tst6();
|
||||
std::cout << "done" << "\n";
|
||||
return has_violations() ? 1 : 0;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue