feat(kernel/expr): add efficient get_size() function for expressions

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-20 12:28:37 -08:00
parent 5224df56a3
commit ac9f8f340d
4 changed files with 52 additions and 6 deletions

View file

@ -129,16 +129,21 @@ expr mk_app(unsigned n, expr const * as) {
expr * m_args = to_app(r)->m_args; expr * m_args = to_app(r)->m_args;
unsigned i = 0; unsigned i = 0;
unsigned j = 0; unsigned j = 0;
unsigned total_size = 1;
if (new_n != n) { if (new_n != n) {
for (; i < n0; i++) for (; i < n0; ++i) {
new (m_args+i) expr(arg(arg0, i)); new (m_args+i) expr(arg(arg0, i));
total_size += get_size(m_args[i]);
}
j++; j++;
} }
for (; i < new_n; ++i, ++j) { for (; i < new_n; ++i, ++j) {
lean_assert(j < n); lean_assert(j < n);
new (m_args+i) expr(as[j]); new (m_args+i) expr(as[j]);
total_size += get_size(m_args[i]);
} }
to_app(r)->m_hash = hash_args(new_n, m_args); to_app(r)->m_hash = hash_args(new_n, m_args);
to_app(r)->m_total_size = total_size;
return r; return r;
} }
expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & b): expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & b):
@ -146,6 +151,7 @@ expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t,
m_name(n), m_name(n),
m_domain(t), m_domain(t),
m_body(b) { m_body(b) {
m_total_size = 1 + get_size(m_domain) + get_size(m_body);
} }
void expr_abstraction::dealloc(buffer<expr_cell*> & todelete) { void expr_abstraction::dealloc(buffer<expr_cell*> & todelete) {
dec_ref(m_body, todelete); dec_ref(m_body, todelete);
@ -165,6 +171,9 @@ expr_let::expr_let(name const & n, optional<expr> const & t, expr const & v, exp
m_type(t), m_type(t),
m_value(v), m_value(v),
m_body(b) { m_body(b) {
m_total_size = 1 + get_size(m_value) + get_size(m_body);
if (m_type)
m_total_size += get_size(*m_type);
} }
void expr_let::dealloc(buffer<expr_cell*> & todelete) { void expr_let::dealloc(buffer<expr_cell*> & todelete) {
dec_ref(m_body, todelete); dec_ref(m_body, todelete);
@ -266,6 +275,20 @@ bool is_arrow(expr const & t) {
} }
} }
unsigned get_size(expr const & e) {
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Type:
case expr_kind::Value: case expr_kind::MetaVar:
return 1;
case expr_kind::App:
return to_app(e)->m_total_size;
case expr_kind::Pi: case expr_kind::Lambda:
return to_abstraction(e)->m_total_size;
case expr_kind::Let:
return to_let(e)->m_total_size;
}
}
expr copy(expr const & a) { expr copy(expr const & a) {
switch (a.kind()) { switch (a.kind()) {
case expr_kind::Var: return mk_var(var_idx(a)); case expr_kind::Var: return mk_var(var_idx(a));

View file

@ -75,7 +75,7 @@ protected:
// 2 - term contains metavariables // 2 - term contains metavariables
// 3-4 - term is an arrow (0 - not initialized, 1 - is arrow, 2 - is not arrow) // 3-4 - term is an arrow (0 - not initialized, 1 - is arrow, 2 - is not arrow)
atomic_ushort m_flags; atomic_ushort m_flags;
unsigned m_hash; // hash based on the structure of the expression (this is a good hash for structural equality) unsigned m_hash; // hash based on the structure of the expression (this is a good hash for structural equality)
unsigned m_hash_alloc; // hash based on 'time' of allocation (this is a good hash for pointer-based equality) unsigned m_hash_alloc; // hash based on 'time' of allocation (this is a good hash for pointer-based equality)
MK_LEAN_RC(); // Declare m_rc counter MK_LEAN_RC(); // Declare m_rc counter
void dealloc(); void dealloc();
@ -199,11 +199,13 @@ public:
}; };
/** \brief Function Applications */ /** \brief Function Applications */
class expr_app : public expr_cell { class expr_app : public expr_cell {
unsigned m_total_size;
unsigned m_num_args; unsigned m_num_args;
expr m_args[0]; expr m_args[0];
friend expr mk_app(unsigned num_args, expr const * args); friend expr mk_app(unsigned num_args, expr const * args);
friend expr_cell; friend expr_cell;
void dealloc(buffer<expr_cell*> & todelete); void dealloc(buffer<expr_cell*> & todelete);
friend unsigned get_size(expr const & e);
public: public:
expr_app(unsigned size, bool has_mv); expr_app(unsigned size, bool has_mv);
unsigned get_num_args() const { return m_num_args; } unsigned get_num_args() const { return m_num_args; }
@ -213,11 +215,13 @@ public:
}; };
/** \brief Super class for lambda abstraction and pi (functional spaces). */ /** \brief Super class for lambda abstraction and pi (functional spaces). */
class expr_abstraction : public expr_cell { class expr_abstraction : public expr_cell {
unsigned m_total_size;
name m_name; name m_name;
expr m_domain; expr m_domain;
expr m_body; expr m_body;
friend class expr_cell; friend class expr_cell;
void dealloc(buffer<expr_cell*> & todelete); void dealloc(buffer<expr_cell*> & todelete);
friend unsigned get_size(expr const & e);
public: public:
expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & e); expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & e);
name const & get_name() const { return m_name; } name const & get_name() const { return m_name; }
@ -236,19 +240,21 @@ public:
}; };
/** \brief Let expressions */ /** \brief Let expressions */
class expr_let : public expr_cell { class expr_let : public expr_cell {
unsigned m_total_size;
name m_name; name m_name;
optional<expr> m_type; optional<expr> m_type;
expr m_value; expr m_value;
expr m_body; expr m_body;
friend class expr_cell; friend class expr_cell;
void dealloc(buffer<expr_cell*> & todelete); void dealloc(buffer<expr_cell*> & todelete);
friend unsigned get_size(expr const & e);
public: public:
expr_let(name const & n, optional<expr> const & t, expr const & v, expr const & b); expr_let(name const & n, optional<expr> const & t, expr const & v, expr const & b);
~expr_let(); ~expr_let();
name const & get_name() const { return m_name; } name const & get_name() const { return m_name; }
optional<expr> const & get_type() const { return m_type; } optional<expr> const & get_type() const { return m_type; }
expr const & get_value() const { return m_value; } expr const & get_value() const { return m_value; }
expr const & get_body() const { return m_body; } expr const & get_body() const { return m_body; }
}; };
/** \brief Type */ /** \brief Type */
class expr_type : public expr_cell { class expr_type : public expr_cell {
@ -517,6 +523,8 @@ inline expr const & let_value(expr const & e) { return to_let(e)->ge
inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); }
inline name const & metavar_name(expr const & e) { return to_metavar(e)->get_name(); } inline name const & metavar_name(expr const & e) { return to_metavar(e)->get_name(); }
inline local_context const & metavar_lctx(expr const & e) { return to_metavar(e)->get_lctx(); } inline local_context const & metavar_lctx(expr const & e) { return to_metavar(e)->get_lctx(); }
/** \brief Return the size of the given expression */
unsigned get_size(expr const & e);
inline bool has_metavar(expr const & e) { return e.has_metavar(); } inline bool has_metavar(expr const & e) { return e.has_metavar(); }
// ======================================= // =======================================

View file

@ -631,6 +631,11 @@ static int expr_mk_eq(lua_State * L) {
return push_expr(L, mk_eq(to_expr(L, 1), to_expr(L, 2), to_expr(L, 3))); return push_expr(L, mk_eq(to_expr(L, 1), to_expr(L, 2), to_expr(L, 3)));
} }
static int expr_size(lua_State * L) {
lua_pushinteger(L, get_size(to_expr(L, 1)));
return 1;
}
static const struct luaL_Reg expr_m[] = { static const struct luaL_Reg expr_m[] = {
{"__gc", expr_gc}, // never throws {"__gc", expr_gc}, // never throws
{"__tostring", safe_function<expr_tostring>}, {"__tostring", safe_function<expr_tostring>},
@ -651,6 +656,7 @@ static const struct luaL_Reg expr_m[] = {
{"data", safe_function<expr_fields>}, {"data", safe_function<expr_fields>},
{"args", safe_function<expr_args>}, {"args", safe_function<expr_args>},
{"num_args", safe_function<expr_num_args>}, {"num_args", safe_function<expr_num_args>},
{"size", safe_function<expr_size>},
{"arg", safe_function<expr_arg>}, {"arg", safe_function<expr_arg>},
{"abst_name", safe_function<expr_abst_name>}, {"abst_name", safe_function<expr_abst_name>},
{"abst_domain", safe_function<expr_abst_domain>}, {"abst_domain", safe_function<expr_abst_domain>},

9
tests/lua/size.lua Normal file
View file

@ -0,0 +1,9 @@
local f = Const("f")
assert(f:size() == 1)
local x = Var(0)
assert(x:size() == 1)
assert(f(x):size() == 3)
local t = f(x)
assert(t(x):size() == 4)
assert(t(x)(f(x)):size() == 7)
assert(mk_lambda("x", Const("Nat"), f(x)):size() == 5)