diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index b9be760b2..f3393b0dd 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -129,16 +129,21 @@ expr mk_app(unsigned n, expr const * as) { expr * m_args = to_app(r)->m_args; unsigned i = 0; unsigned j = 0; + unsigned total_size = 1; if (new_n != n) { - for (; i < n0; i++) + for (; i < n0; ++i) { new (m_args+i) expr(arg(arg0, i)); + total_size += get_size(m_args[i]); + } j++; } for (; i < new_n; ++i, ++j) { lean_assert(j < n); new (m_args+i) expr(as[j]); + total_size += get_size(m_args[i]); } - to_app(r)->m_hash = hash_args(new_n, m_args); + to_app(r)->m_hash = hash_args(new_n, m_args); + to_app(r)->m_total_size = total_size; return r; } expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & b): @@ -146,6 +151,7 @@ expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, m_name(n), m_domain(t), m_body(b) { + m_total_size = 1 + get_size(m_domain) + get_size(m_body); } void expr_abstraction::dealloc(buffer & todelete) { dec_ref(m_body, todelete); @@ -165,6 +171,9 @@ expr_let::expr_let(name const & n, optional const & t, expr const & v, exp m_type(t), m_value(v), m_body(b) { + m_total_size = 1 + get_size(m_value) + get_size(m_body); + if (m_type) + m_total_size += get_size(*m_type); } void expr_let::dealloc(buffer & todelete) { dec_ref(m_body, todelete); @@ -266,6 +275,20 @@ bool is_arrow(expr const & t) { } } +unsigned get_size(expr const & e) { + switch (e.kind()) { + case expr_kind::Var: case expr_kind::Constant: case expr_kind::Type: + case expr_kind::Value: case expr_kind::MetaVar: + return 1; + case expr_kind::App: + return to_app(e)->m_total_size; + case expr_kind::Pi: case expr_kind::Lambda: + return to_abstraction(e)->m_total_size; + case expr_kind::Let: + return to_let(e)->m_total_size; + } +} + expr copy(expr const & a) { switch (a.kind()) { case expr_kind::Var: return mk_var(var_idx(a)); diff --git a/src/kernel/expr.h b/src/kernel/expr.h index ae58afa0d..8649f9ed2 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -75,7 +75,7 @@ protected: // 2 - term contains metavariables // 3-4 - term is an arrow (0 - not initialized, 1 - is arrow, 2 - is not arrow) atomic_ushort m_flags; - unsigned m_hash; // hash based on the structure of the expression (this is a good hash for structural equality) + unsigned m_hash; // hash based on the structure of the expression (this is a good hash for structural equality) unsigned m_hash_alloc; // hash based on 'time' of allocation (this is a good hash for pointer-based equality) MK_LEAN_RC(); // Declare m_rc counter void dealloc(); @@ -199,11 +199,13 @@ public: }; /** \brief Function Applications */ class expr_app : public expr_cell { + unsigned m_total_size; unsigned m_num_args; expr m_args[0]; friend expr mk_app(unsigned num_args, expr const * args); friend expr_cell; void dealloc(buffer & todelete); + friend unsigned get_size(expr const & e); public: expr_app(unsigned size, bool has_mv); unsigned get_num_args() const { return m_num_args; } @@ -213,11 +215,13 @@ public: }; /** \brief Super class for lambda abstraction and pi (functional spaces). */ class expr_abstraction : public expr_cell { + unsigned m_total_size; name m_name; expr m_domain; expr m_body; friend class expr_cell; void dealloc(buffer & todelete); + friend unsigned get_size(expr const & e); public: expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & e); name const & get_name() const { return m_name; } @@ -236,19 +240,21 @@ public: }; /** \brief Let expressions */ class expr_let : public expr_cell { + unsigned m_total_size; name m_name; optional m_type; expr m_value; expr m_body; friend class expr_cell; void dealloc(buffer & todelete); + friend unsigned get_size(expr const & e); public: expr_let(name const & n, optional const & t, expr const & v, expr const & b); ~expr_let(); - name const & get_name() const { return m_name; } + name const & get_name() const { return m_name; } optional const & get_type() const { return m_type; } - expr const & get_value() const { return m_value; } - expr const & get_body() const { return m_body; } + expr const & get_value() const { return m_value; } + expr const & get_body() const { return m_body; } }; /** \brief Type */ class expr_type : public expr_cell { @@ -517,6 +523,8 @@ inline expr const & let_value(expr const & e) { return to_let(e)->ge inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } inline name const & metavar_name(expr const & e) { return to_metavar(e)->get_name(); } inline local_context const & metavar_lctx(expr const & e) { return to_metavar(e)->get_lctx(); } +/** \brief Return the size of the given expression */ +unsigned get_size(expr const & e); inline bool has_metavar(expr const & e) { return e.has_metavar(); } // ======================================= diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index d061b28d3..5394fa0ff 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -631,6 +631,11 @@ static int expr_mk_eq(lua_State * L) { return push_expr(L, mk_eq(to_expr(L, 1), to_expr(L, 2), to_expr(L, 3))); } +static int expr_size(lua_State * L) { + lua_pushinteger(L, get_size(to_expr(L, 1))); + return 1; +} + static const struct luaL_Reg expr_m[] = { {"__gc", expr_gc}, // never throws {"__tostring", safe_function}, @@ -651,6 +656,7 @@ static const struct luaL_Reg expr_m[] = { {"data", safe_function}, {"args", safe_function}, {"num_args", safe_function}, + {"size", safe_function}, {"arg", safe_function}, {"abst_name", safe_function}, {"abst_domain", safe_function}, diff --git a/tests/lua/size.lua b/tests/lua/size.lua new file mode 100644 index 000000000..6b6b437af --- /dev/null +++ b/tests/lua/size.lua @@ -0,0 +1,9 @@ +local f = Const("f") +assert(f:size() == 1) +local x = Var(0) +assert(x:size() == 1) +assert(f(x):size() == 3) +local t = f(x) +assert(t(x):size() == 4) +assert(t(x)(f(x)):size() == 7) +assert(mk_lambda("x", Const("Nat"), f(x)):size() == 5)