Add max_shared flag to expr_cell. Improve app constructor.
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
2a9d0de57b
commit
5aa25a635f
5 changed files with 122 additions and 76 deletions
|
@ -20,7 +20,8 @@ unsigned hash_vars(unsigned size, uvar const * vars) {
|
||||||
}
|
}
|
||||||
|
|
||||||
expr_cell::expr_cell(expr_kind k, unsigned h):
|
expr_cell::expr_cell(expr_kind k, unsigned h):
|
||||||
m_kind(k),
|
m_kind(static_cast<unsigned>(k)),
|
||||||
|
m_max_shared(0),
|
||||||
m_hash(h),
|
m_hash(h),
|
||||||
m_rc(1) {}
|
m_rc(1) {}
|
||||||
|
|
||||||
|
@ -103,7 +104,7 @@ expr_numeral::expr_numeral(mpz const & n):
|
||||||
m_numeral(n) {}
|
m_numeral(n) {}
|
||||||
|
|
||||||
void expr_cell::dealloc() {
|
void expr_cell::dealloc() {
|
||||||
switch (m_kind) {
|
switch (kind()) {
|
||||||
case expr_kind::Var: delete static_cast<expr_var*>(this); break;
|
case expr_kind::Var: delete static_cast<expr_var*>(this); break;
|
||||||
case expr_kind::Constant: delete static_cast<expr_const*>(this); break;
|
case expr_kind::Constant: delete static_cast<expr_const*>(this); break;
|
||||||
case expr_kind::App: static_cast<expr_app*>(this)->~expr_app(); delete[] reinterpret_cast<char*>(this); break;
|
case expr_kind::App: static_cast<expr_app*>(this)->~expr_app(); delete[] reinterpret_cast<char*>(this); break;
|
||||||
|
|
|
@ -33,6 +33,8 @@ The main API is divided in the following sections
|
||||||
======================================= */
|
======================================= */
|
||||||
enum class expr_kind { Var, Constant, App, Lambda, Pi, Prop, Type, Numeral };
|
enum class expr_kind { Var, Constant, App, Lambda, Pi, Prop, Type, Numeral };
|
||||||
|
|
||||||
|
class max_sharing_functor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
\brief Base class used to represent expressions.
|
\brief Base class used to represent expressions.
|
||||||
|
|
||||||
|
@ -42,13 +44,17 @@ enum class expr_kind { Var, Constant, App, Lambda, Pi, Prop, Type, Numeral };
|
||||||
*/
|
*/
|
||||||
class expr_cell {
|
class expr_cell {
|
||||||
protected:
|
protected:
|
||||||
expr_kind m_kind;
|
unsigned m_kind:16;
|
||||||
unsigned m_hash;
|
unsigned m_max_shared:1; // flag indicating if the cell has maximally shared subexpressions
|
||||||
|
unsigned m_hash;
|
||||||
MK_LEAN_RC(); // Declare m_rc counter
|
MK_LEAN_RC(); // Declare m_rc counter
|
||||||
void dealloc();
|
void dealloc();
|
||||||
|
bool max_shared() const { return m_max_shared == 1; }
|
||||||
|
void set_max_shared() { lean_assert(!max_shared()); m_max_shared = 1; }
|
||||||
|
friend class max_sharing_functor;
|
||||||
public:
|
public:
|
||||||
expr_cell(expr_kind k, unsigned h);
|
expr_cell(expr_kind k, unsigned h);
|
||||||
expr_kind kind() const { return m_kind; }
|
expr_kind kind() const { return static_cast<expr_kind>(m_kind); }
|
||||||
unsigned hash() const { return m_hash; }
|
unsigned hash() const { return m_hash; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -119,7 +125,7 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
// =======================================
|
// =======================================
|
||||||
// Expr Representation
|
// Expr (internal) Representation
|
||||||
// 1. Free variables
|
// 1. Free variables
|
||||||
class expr_var : public expr_cell {
|
class expr_var : public expr_cell {
|
||||||
unsigned m_vidx; // de Bruijn index
|
unsigned m_vidx; // de Bruijn index
|
||||||
|
@ -225,17 +231,21 @@ inline expr var(unsigned idx) { return expr(new expr_var(idx)); }
|
||||||
inline expr constant(name const & n) { return expr(new expr_const(n)); }
|
inline expr constant(name const & n) { return expr(new expr_const(n)); }
|
||||||
inline expr constant(name const & n, unsigned pos) { return expr(new expr_const(n, pos)); }
|
inline expr constant(name const & n, unsigned pos) { return expr(new expr_const(n, pos)); }
|
||||||
expr app(unsigned num_args, expr const * args);
|
expr app(unsigned num_args, expr const * args);
|
||||||
inline expr app(std::initializer_list<expr> const & l) { return app(l.size(), l.begin()); }
|
inline expr app(expr const & e1, expr const & e2) { expr args[2] = {e1, e2}; return app(2, args); }
|
||||||
|
inline expr app(expr const & e1, expr const & e2, expr const & e3) { expr args[3] = {e1, e2, e3}; return app(3, args); }
|
||||||
|
inline expr app(expr const & e1, expr const & e2, expr const & e3, expr const & e4) { expr args[4] = {e1, e2, e3, e4}; return app(4, args); }
|
||||||
|
inline expr app(expr const & e1, expr const & e2, expr const & e3, expr const & e4, expr const & e5) { expr args[5] = {e1, e2, e3, e4, e5}; return app(5, args); }
|
||||||
inline expr lambda(name const & n, expr const & t, expr const & e) { return expr(new expr_lambda(n, t, e)); }
|
inline expr lambda(name const & n, expr const & t, expr const & e) { return expr(new expr_lambda(n, t, e)); }
|
||||||
inline expr pi(name const & n, expr const & t, expr const & e) { return expr(new expr_pi(n, t, e)); }
|
inline expr pi(name const & n, expr const & t, expr const & e) { return expr(new expr_pi(n, t, e)); }
|
||||||
inline expr prop() { return expr(new expr_prop()); }
|
inline expr prop() { return expr(new expr_prop()); }
|
||||||
expr type(unsigned size, uvar const * vars);
|
expr type(unsigned size, uvar const * vars);
|
||||||
|
inline expr type(uvar const & uv) { return type(1, &uv); }
|
||||||
inline expr type(std::initializer_list<uvar> const & l) { return type(l.size(), l.begin()); }
|
inline expr type(std::initializer_list<uvar> const & l) { return type(l.size(), l.begin()); }
|
||||||
inline expr numeral(mpz const & n) { return expr(new expr_numeral(n)); }
|
inline expr numeral(mpz const & n) { return expr(new expr_numeral(n)); }
|
||||||
// =======================================
|
// =======================================
|
||||||
|
|
||||||
// =======================================
|
// =======================================
|
||||||
// Casting
|
// Casting (these functions are only needed for low-level code)
|
||||||
inline expr_var * to_var(expr_cell * e) { lean_assert(is_var(e)); return static_cast<expr_var*>(e); }
|
inline expr_var * to_var(expr_cell * e) { lean_assert(is_var(e)); return static_cast<expr_var*>(e); }
|
||||||
inline expr_const * to_constant(expr_cell * e) { lean_assert(is_constant(e)); return static_cast<expr_const*>(e); }
|
inline expr_const * to_constant(expr_cell * e) { lean_assert(is_constant(e)); return static_cast<expr_const*>(e); }
|
||||||
inline expr_app * to_app(expr_cell * e) { lean_assert(is_app(e)); return static_cast<expr_app*>(e); }
|
inline expr_app * to_app(expr_cell * e) { lean_assert(is_app(e)); return static_cast<expr_app*>(e); }
|
||||||
|
|
|
@ -6,65 +6,84 @@ Author: Leonardo de Moura
|
||||||
*/
|
*/
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "expr.h"
|
#include "expr_max_shared.h"
|
||||||
#include "expr_functors.h"
|
#include "expr_functors.h"
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
namespace max_shared_ns {
|
|
||||||
struct expr_struct_eq { unsigned operator()(expr const & e1, expr const & e2) const { return e1 == e2; }};
|
class max_sharing_functor {
|
||||||
typedef typename std::unordered_set<expr, expr_hash, expr_struct_eq> expr_cache;
|
struct expr_struct_eq { unsigned operator()(expr const & e1, expr const & e2) const { return e1 == e2; }};
|
||||||
static thread_local expr_cache g_cache;
|
typedef typename std::unordered_set<expr, expr_hash, expr_struct_eq> expr_cache;
|
||||||
expr apply(expr const & a) {
|
|
||||||
auto r = g_cache.find(a);
|
expr_cache m_cache;
|
||||||
if (r != g_cache.end())
|
|
||||||
return *r;
|
void cache(expr const & a) {
|
||||||
switch (a.kind()) {
|
a.raw()->set_max_shared();
|
||||||
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
|
m_cache.insert(a);
|
||||||
g_cache.insert(a);
|
|
||||||
return a;
|
|
||||||
case expr_kind::App: {
|
|
||||||
std::vector<expr> new_args;
|
|
||||||
bool modified = false;
|
|
||||||
for (expr const & old_arg : app_args(a)) {
|
|
||||||
new_args.push_back(apply(old_arg));
|
|
||||||
if (!eqp(old_arg, new_args.back()))
|
|
||||||
modified = true;
|
|
||||||
}
|
|
||||||
if (!modified) {
|
|
||||||
g_cache.insert(a);
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
expr r = app(static_cast<unsigned>(new_args.size()), new_args.data());
|
|
||||||
g_cache.insert(r);
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case expr_kind::Lambda:
|
|
||||||
case expr_kind::Pi: {
|
public:
|
||||||
expr const & old_t = get_abs_type(a);
|
|
||||||
expr const & old_e = get_abs_expr(a);
|
expr apply(expr const & a) {
|
||||||
expr t = apply(old_t);
|
if (a.raw()->max_shared())
|
||||||
expr e = apply(old_e);
|
|
||||||
if (!eqp(t, old_t) || !eqp(e, old_e)) {
|
|
||||||
name const & n = get_abs_name(a);
|
|
||||||
expr r = is_pi(a) ? pi(n, t, e) : lambda(n, t, e);
|
|
||||||
g_cache.insert(r);
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
g_cache.insert(a);
|
|
||||||
return a;
|
return a;
|
||||||
|
auto r = m_cache.find(a);
|
||||||
|
if (r != m_cache.end()) {
|
||||||
|
lean_assert((*r).raw()->max_shared());
|
||||||
|
return *r;
|
||||||
}
|
}
|
||||||
}}
|
switch (a.kind()) {
|
||||||
lean_unreachable();
|
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
|
||||||
return a;
|
cache(a);
|
||||||
}
|
return a;
|
||||||
} // namespace max_shared_ns
|
case expr_kind::App: {
|
||||||
expr max_shared(expr const & a) {
|
std::vector<expr> new_args;
|
||||||
max_shared_ns::g_cache.clear();
|
bool modified = false;
|
||||||
expr r = max_shared_ns::apply(a);
|
for (expr const & old_arg : app_args(a)) {
|
||||||
max_shared_ns::g_cache.clear();
|
new_args.push_back(apply(old_arg));
|
||||||
return r;
|
if (!eqp(old_arg, new_args.back()))
|
||||||
|
modified = true;
|
||||||
|
}
|
||||||
|
if (!modified) {
|
||||||
|
cache(a);
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
expr r = app(new_args.size(), new_args.data());
|
||||||
|
cache(r);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case expr_kind::Lambda:
|
||||||
|
case expr_kind::Pi: {
|
||||||
|
expr const & old_t = get_abs_type(a);
|
||||||
|
expr const & old_e = get_abs_expr(a);
|
||||||
|
expr t = apply(old_t);
|
||||||
|
expr e = apply(old_e);
|
||||||
|
if (!eqp(t, old_t) || !eqp(e, old_e)) {
|
||||||
|
name const & n = get_abs_name(a);
|
||||||
|
expr r = is_pi(a) ? pi(n, t, e) : lambda(n, t, e);
|
||||||
|
cache(r);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
cache(a);
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
lean_unreachable();
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
expr max_sharing(expr const & a) {
|
||||||
|
if (a.raw()->max_shared()) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
max_sharing_functor f;
|
||||||
|
return f.apply(a);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace lean
|
} // namespace lean
|
||||||
|
|
|
@ -9,7 +9,8 @@ Author: Leonardo de Moura
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
/**
|
/**
|
||||||
\brief Return a structurally identical expression that is maximally shared.
|
\brief The resultant expression is structurally identical to the input one, but
|
||||||
|
it uses maximally shared sub-expressions.
|
||||||
*/
|
*/
|
||||||
expr max_shared(expr const & a);
|
expr max_sharing(expr const & a);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,15 +16,17 @@ void tst1() {
|
||||||
a = numeral(mpz(10));
|
a = numeral(mpz(10));
|
||||||
expr f;
|
expr f;
|
||||||
f = var(0);
|
f = var(0);
|
||||||
expr fa = app({f, a});
|
expr fa = app(f, a);
|
||||||
std::cout << fa << "\n";
|
std::cout << fa << "\n";
|
||||||
std::cout << app({fa, a}) << "\n";
|
std::cout << app(fa, a) << "\n";
|
||||||
lean_assert(eqp(get_arg(fa, 0), f));
|
lean_assert(eqp(get_arg(fa, 0), f));
|
||||||
lean_assert(eqp(get_arg(fa, 1), a));
|
lean_assert(eqp(get_arg(fa, 1), a));
|
||||||
lean_assert(!eqp(fa, app({f, a})));
|
lean_assert(!eqp(fa, app(f, a)));
|
||||||
lean_assert(app({fa, a}) == app({f, a, a}));
|
lean_assert(app(fa, a) == app(f, a, a));
|
||||||
std::cout << app({fa, fa, fa}) << "\n";
|
std::cout << app(fa, fa, fa) << "\n";
|
||||||
std::cout << lambda(name("x"), prop(), var(0)) << "\n";
|
std::cout << lambda(name("x"), prop(), var(0)) << "\n";
|
||||||
|
lean_assert(app(app(f, a), a) == app(f, a, a));
|
||||||
|
lean_assert(app(f, app(a, a)) != app(f, a, a));
|
||||||
}
|
}
|
||||||
|
|
||||||
expr mk_dag(unsigned depth) {
|
expr mk_dag(unsigned depth) {
|
||||||
|
@ -32,7 +34,7 @@ expr mk_dag(unsigned depth) {
|
||||||
expr a = var(0);
|
expr a = var(0);
|
||||||
while (depth > 0) {
|
while (depth > 0) {
|
||||||
depth--;
|
depth--;
|
||||||
a = app({f, a, a});
|
a = app(f, a, a);
|
||||||
}
|
}
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
@ -110,7 +112,7 @@ expr mk_big(expr f, unsigned depth, unsigned val) {
|
||||||
if (depth == 1)
|
if (depth == 1)
|
||||||
return var(val);
|
return var(val);
|
||||||
else
|
else
|
||||||
return app({f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1)});
|
return app(f, mk_big(f, depth - 1, val << 1), mk_big(f, depth - 1, (val << 1) + 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
void tst3() {
|
void tst3() {
|
||||||
|
@ -124,7 +126,7 @@ void tst4() {
|
||||||
expr f = constant(name("f"));
|
expr f = constant(name("f"));
|
||||||
expr a = var(0);
|
expr a = var(0);
|
||||||
for (unsigned i = 0; i < 10000; i++) {
|
for (unsigned i = 0; i < 10000; i++) {
|
||||||
a = app({f, a});
|
a = app(f, a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,7 +134,7 @@ expr mk_redundant_dag(expr f, unsigned depth) {
|
||||||
if (depth == 0)
|
if (depth == 0)
|
||||||
return var(0);
|
return var(0);
|
||||||
else
|
else
|
||||||
return app({f, mk_redundant_dag(f, depth - 1), mk_redundant_dag(f, depth - 1)});
|
return app(f, mk_redundant_dag(f, depth - 1), mk_redundant_dag(f, depth - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -161,20 +163,32 @@ void tst5() {
|
||||||
expr f = constant(name("f"));
|
expr f = constant(name("f"));
|
||||||
{
|
{
|
||||||
expr r1 = mk_redundant_dag(f, 5);
|
expr r1 = mk_redundant_dag(f, 5);
|
||||||
expr r2 = max_shared(r1);
|
expr r2 = max_sharing(r1);
|
||||||
std::cout << "count(r1): " << count(r1) << "\n";
|
std::cout << "count(r1): " << count(r1) << "\n";
|
||||||
std::cout << "count(r2): " << count(r2) << "\n";
|
std::cout << "count(r2): " << count(r2) << "\n";
|
||||||
lean_assert(r1 == r2);
|
lean_assert(r1 == r2);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
expr r1 = mk_redundant_dag(f, 16);
|
expr r1 = mk_redundant_dag(f, 16);
|
||||||
expr r2 = max_shared(r1);
|
expr r2 = max_sharing(r1);
|
||||||
lean_assert(r1 == r2);
|
lean_assert(r1 == r2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void tst6() {
|
||||||
|
expr f = constant(name("f"));
|
||||||
|
expr r = mk_redundant_dag(f, 12);
|
||||||
|
for (unsigned i = 0; i < 1000; i++) {
|
||||||
|
r = max_sharing(r);
|
||||||
|
}
|
||||||
|
r = mk_big(f, 16, 0);
|
||||||
|
for (unsigned i = 0; i < 1000000; i++) {
|
||||||
|
r = max_sharing(r);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
// continue_on_violation(true);
|
continue_on_violation(true);
|
||||||
std::cout << "sizeof(expr): " << sizeof(expr) << "\n";
|
std::cout << "sizeof(expr): " << sizeof(expr) << "\n";
|
||||||
std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n";
|
std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n";
|
||||||
tst1();
|
tst1();
|
||||||
|
@ -182,6 +196,7 @@ int main() {
|
||||||
tst3();
|
tst3();
|
||||||
tst4();
|
tst4();
|
||||||
tst5();
|
tst5();
|
||||||
|
tst6();
|
||||||
std::cout << "done" << "\n";
|
std::cout << "done" << "\n";
|
||||||
return has_violations() ? 1 : 0;
|
return has_violations() ? 1 : 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue