From 6bcd8e3ee5c88f66d93b10e6593824ee6b851a25 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 20 Jan 2014 17:40:44 -0800 Subject: [PATCH] fix(library/expr_lt): use expression depth instead of size to obtain a monotonic total order on terms It is not incorrect to use size, but it can easily overflow due to sharing. The following script demonstrates the problem: local f = Const("f") local a = Const("a") function mk_shared(d) if d == 0 then return a else local c = mk_shared(d-1) return f(c, c) end end print(mk_shared(33):size()) Signed-off-by: Leonardo de Moura --- src/kernel/expr.cpp | 26 ++++++++++++++------------ src/kernel/expr.h | 16 ++++++++-------- src/library/expr_lt.cpp | 8 ++++---- src/library/kernel_bindings.cpp | 8 ++++---- tests/lua/size.lua | 9 --------- 5 files changed, 30 insertions(+), 37 deletions(-) delete mode 100644 tests/lua/size.lua diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 8be111439..a39f7d40f 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include #include #include +#include #include "util/hash.h" #include "util/buffer.h" #include "util/object_serializer.h" @@ -129,21 +130,21 @@ expr mk_app(unsigned n, expr const * as) { expr * m_args = to_app(r)->m_args; unsigned i = 0; unsigned j = 0; - unsigned total_size = 1; + unsigned depth = 0; if (new_n != n) { for (; i < n0; ++i) { new (m_args+i) expr(arg(arg0, i)); - total_size += get_size(m_args[i]); + depth = std::max(depth, get_depth(m_args[i])); } j++; } for (; i < new_n; ++i, ++j) { lean_assert(j < n); new (m_args+i) expr(as[j]); - total_size += get_size(m_args[i]); + depth = std::max(depth, get_depth(m_args[i])); } - to_app(r)->m_hash = hash_args(new_n, m_args); - to_app(r)->m_total_size = total_size; + to_app(r)->m_hash = hash_args(new_n, m_args); + to_app(r)->m_depth = depth + 1; return r; } expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & b): @@ -151,7 +152,7 @@ expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, m_name(n), m_domain(t), m_body(b) { - m_total_size = 1 + get_size(m_domain) + get_size(m_body); + m_depth = 1 + std::max(get_depth(m_domain), get_depth(m_body)); } void expr_abstraction::dealloc(buffer & todelete) { dec_ref(m_body, todelete); @@ -171,9 +172,10 @@ expr_let::expr_let(name const & n, optional const & t, expr const & v, exp m_type(t), m_value(v), m_body(b) { - m_total_size = 1 + get_size(m_value) + get_size(m_body); + unsigned depth = std::max(get_depth(m_value), get_depth(m_body)); if (m_type) - m_total_size += get_size(*m_type); + depth = std::max(depth, get_depth(*m_type)); + m_depth = 1 + depth; } void expr_let::dealloc(buffer & todelete) { dec_ref(m_body, todelete); @@ -275,17 +277,17 @@ bool is_arrow(expr const & t) { } } -unsigned get_size(expr const & e) { +unsigned get_depth(expr const & e) { switch (e.kind()) { case expr_kind::Var: case expr_kind::Constant: case expr_kind::Type: case expr_kind::Value: case expr_kind::MetaVar: return 1; case expr_kind::App: - return to_app(e)->m_total_size; + return to_app(e)->m_depth; case expr_kind::Pi: case expr_kind::Lambda: - return to_abstraction(e)->m_total_size; + return to_abstraction(e)->m_depth; case expr_kind::Let: - return to_let(e)->m_total_size; + return to_let(e)->m_depth; } lean_unreachable(); // LCOV_EXCL_LINE } diff --git a/src/kernel/expr.h b/src/kernel/expr.h index d5afa3931..1130aa3cd 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -199,13 +199,13 @@ public: }; /** \brief Function Applications */ class expr_app : public expr_cell { - unsigned m_total_size; + unsigned m_depth; unsigned m_num_args; expr m_args[0]; friend expr mk_app(unsigned num_args, expr const * args); friend expr_cell; void dealloc(buffer & todelete); - friend unsigned get_size(expr const & e); + friend unsigned get_depth(expr const & e); public: expr_app(unsigned size, bool has_mv); unsigned get_num_args() const { return m_num_args; } @@ -215,13 +215,13 @@ public: }; /** \brief Super class for lambda abstraction and pi (functional spaces). */ class expr_abstraction : public expr_cell { - unsigned m_total_size; + unsigned m_depth; name m_name; expr m_domain; expr m_body; friend class expr_cell; void dealloc(buffer & todelete); - friend unsigned get_size(expr const & e); + friend unsigned get_depth(expr const & e); public: expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & e); name const & get_name() const { return m_name; } @@ -240,14 +240,14 @@ public: }; /** \brief Let expressions */ class expr_let : public expr_cell { - unsigned m_total_size; + unsigned m_depth; name m_name; optional m_type; expr m_value; expr m_body; friend class expr_cell; void dealloc(buffer & todelete); - friend unsigned get_size(expr const & e); + friend unsigned get_depth(expr const & e); public: expr_let(name const & n, optional const & t, expr const & v, expr const & b); ~expr_let(); @@ -523,8 +523,8 @@ inline expr const & let_value(expr const & e) { return to_let(e)->ge inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } inline name const & metavar_name(expr const & e) { return to_metavar(e)->get_name(); } inline local_context const & metavar_lctx(expr const & e) { return to_metavar(e)->get_lctx(); } -/** \brief Return the size of the given expression */ -unsigned get_size(expr const & e); +/** \brief Return the depth of the given expression */ +unsigned get_depth(expr const & e); inline bool has_metavar(expr const & e) { return e.has_metavar(); } // ======================================= diff --git a/src/library/expr_lt.cpp b/src/library/expr_lt.cpp index 460cac125..f2808ec2b 100644 --- a/src/library/expr_lt.cpp +++ b/src/library/expr_lt.cpp @@ -17,10 +17,10 @@ static bool is_lt(optional const & a, optional const & b, bool use_h bool is_lt(expr const & a, expr const & b, bool use_hash) { if (is_eqp(a, b)) return false; - unsigned a_sz = get_size(a); - unsigned b_sz = get_size(b); - if (a_sz < b_sz) return true; - if (a_sz > b_sz) return false; + unsigned da = get_depth(a); + unsigned db = get_depth(b); + if (da < db) return true; + if (da > db) return false; if (a.kind() != b.kind()) return a.kind() < b.kind(); if (use_hash) { if (a.hash() < b.hash()) return true; diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 582ee12b0..3b4720abc 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -631,8 +631,8 @@ static int expr_mk_eq(lua_State * L) { return push_expr(L, mk_eq(to_expr(L, 1), to_expr(L, 2), to_expr(L, 3))); } -static int expr_size(lua_State * L) { - lua_pushinteger(L, get_size(to_expr(L, 1))); +static int expr_depth(lua_State * L) { + lua_pushinteger(L, get_depth(to_expr(L, 1))); return 1; } @@ -651,7 +651,7 @@ static const struct luaL_Reg expr_m[] = { {"is_var", safe_function}, {"is_constant", safe_function}, {"is_app", safe_function}, - {"is_lambda", safe_function}, + {"is_lambda", safe_function}, {"is_pi", safe_function}, {"is_abstraction", safe_function}, {"is_let", safe_function}, @@ -661,7 +661,7 @@ static const struct luaL_Reg expr_m[] = { {"data", safe_function}, {"args", safe_function}, {"num_args", safe_function}, - {"size", safe_function}, + {"depth", safe_function}, {"arg", safe_function}, {"abst_name", safe_function}, {"abst_domain", safe_function}, diff --git a/tests/lua/size.lua b/tests/lua/size.lua deleted file mode 100644 index 6b6b437af..000000000 --- a/tests/lua/size.lua +++ /dev/null @@ -1,9 +0,0 @@ -local f = Const("f") -assert(f:size() == 1) -local x = Var(0) -assert(x:size() == 1) -assert(f(x):size() == 3) -local t = f(x) -assert(t(x):size() == 4) -assert(t(x)(f(x)):size() == 7) -assert(mk_lambda("x", Const("Nat"), f(x)):size() == 5)