fix(library/expr_lt): use expression depth instead of size to obtain a monotonic total order on terms
It is not incorrect to use size, but it can easily overflow due to sharing. The following script demonstrates the problem: local f = Const("f") local a = Const("a") function mk_shared(d) if d == 0 then return a else local c = mk_shared(d-1) return f(c, c) end end print(mk_shared(33):size()) Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
cd19d4da01
commit
6bcd8e3ee5
5 changed files with 30 additions and 37 deletions
|
@ -8,6 +8,7 @@ Author: Leonardo de Moura
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
#include "util/hash.h"
|
#include "util/hash.h"
|
||||||
#include "util/buffer.h"
|
#include "util/buffer.h"
|
||||||
#include "util/object_serializer.h"
|
#include "util/object_serializer.h"
|
||||||
|
@ -129,21 +130,21 @@ expr mk_app(unsigned n, expr const * as) {
|
||||||
expr * m_args = to_app(r)->m_args;
|
expr * m_args = to_app(r)->m_args;
|
||||||
unsigned i = 0;
|
unsigned i = 0;
|
||||||
unsigned j = 0;
|
unsigned j = 0;
|
||||||
unsigned total_size = 1;
|
unsigned depth = 0;
|
||||||
if (new_n != n) {
|
if (new_n != n) {
|
||||||
for (; i < n0; ++i) {
|
for (; i < n0; ++i) {
|
||||||
new (m_args+i) expr(arg(arg0, i));
|
new (m_args+i) expr(arg(arg0, i));
|
||||||
total_size += get_size(m_args[i]);
|
depth = std::max(depth, get_depth(m_args[i]));
|
||||||
}
|
}
|
||||||
j++;
|
j++;
|
||||||
}
|
}
|
||||||
for (; i < new_n; ++i, ++j) {
|
for (; i < new_n; ++i, ++j) {
|
||||||
lean_assert(j < n);
|
lean_assert(j < n);
|
||||||
new (m_args+i) expr(as[j]);
|
new (m_args+i) expr(as[j]);
|
||||||
total_size += get_size(m_args[i]);
|
depth = std::max(depth, get_depth(m_args[i]));
|
||||||
}
|
}
|
||||||
to_app(r)->m_hash = hash_args(new_n, m_args);
|
to_app(r)->m_hash = hash_args(new_n, m_args);
|
||||||
to_app(r)->m_total_size = total_size;
|
to_app(r)->m_depth = depth + 1;
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & b):
|
expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & b):
|
||||||
|
@ -151,7 +152,7 @@ expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t,
|
||||||
m_name(n),
|
m_name(n),
|
||||||
m_domain(t),
|
m_domain(t),
|
||||||
m_body(b) {
|
m_body(b) {
|
||||||
m_total_size = 1 + get_size(m_domain) + get_size(m_body);
|
m_depth = 1 + std::max(get_depth(m_domain), get_depth(m_body));
|
||||||
}
|
}
|
||||||
void expr_abstraction::dealloc(buffer<expr_cell*> & todelete) {
|
void expr_abstraction::dealloc(buffer<expr_cell*> & todelete) {
|
||||||
dec_ref(m_body, todelete);
|
dec_ref(m_body, todelete);
|
||||||
|
@ -171,9 +172,10 @@ expr_let::expr_let(name const & n, optional<expr> const & t, expr const & v, exp
|
||||||
m_type(t),
|
m_type(t),
|
||||||
m_value(v),
|
m_value(v),
|
||||||
m_body(b) {
|
m_body(b) {
|
||||||
m_total_size = 1 + get_size(m_value) + get_size(m_body);
|
unsigned depth = std::max(get_depth(m_value), get_depth(m_body));
|
||||||
if (m_type)
|
if (m_type)
|
||||||
m_total_size += get_size(*m_type);
|
depth = std::max(depth, get_depth(*m_type));
|
||||||
|
m_depth = 1 + depth;
|
||||||
}
|
}
|
||||||
void expr_let::dealloc(buffer<expr_cell*> & todelete) {
|
void expr_let::dealloc(buffer<expr_cell*> & todelete) {
|
||||||
dec_ref(m_body, todelete);
|
dec_ref(m_body, todelete);
|
||||||
|
@ -275,17 +277,17 @@ bool is_arrow(expr const & t) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned get_size(expr const & e) {
|
unsigned get_depth(expr const & e) {
|
||||||
switch (e.kind()) {
|
switch (e.kind()) {
|
||||||
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Type:
|
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Type:
|
||||||
case expr_kind::Value: case expr_kind::MetaVar:
|
case expr_kind::Value: case expr_kind::MetaVar:
|
||||||
return 1;
|
return 1;
|
||||||
case expr_kind::App:
|
case expr_kind::App:
|
||||||
return to_app(e)->m_total_size;
|
return to_app(e)->m_depth;
|
||||||
case expr_kind::Pi: case expr_kind::Lambda:
|
case expr_kind::Pi: case expr_kind::Lambda:
|
||||||
return to_abstraction(e)->m_total_size;
|
return to_abstraction(e)->m_depth;
|
||||||
case expr_kind::Let:
|
case expr_kind::Let:
|
||||||
return to_let(e)->m_total_size;
|
return to_let(e)->m_depth;
|
||||||
}
|
}
|
||||||
lean_unreachable(); // LCOV_EXCL_LINE
|
lean_unreachable(); // LCOV_EXCL_LINE
|
||||||
}
|
}
|
||||||
|
|
|
@ -199,13 +199,13 @@ public:
|
||||||
};
|
};
|
||||||
/** \brief Function Applications */
|
/** \brief Function Applications */
|
||||||
class expr_app : public expr_cell {
|
class expr_app : public expr_cell {
|
||||||
unsigned m_total_size;
|
unsigned m_depth;
|
||||||
unsigned m_num_args;
|
unsigned m_num_args;
|
||||||
expr m_args[0];
|
expr m_args[0];
|
||||||
friend expr mk_app(unsigned num_args, expr const * args);
|
friend expr mk_app(unsigned num_args, expr const * args);
|
||||||
friend expr_cell;
|
friend expr_cell;
|
||||||
void dealloc(buffer<expr_cell*> & todelete);
|
void dealloc(buffer<expr_cell*> & todelete);
|
||||||
friend unsigned get_size(expr const & e);
|
friend unsigned get_depth(expr const & e);
|
||||||
public:
|
public:
|
||||||
expr_app(unsigned size, bool has_mv);
|
expr_app(unsigned size, bool has_mv);
|
||||||
unsigned get_num_args() const { return m_num_args; }
|
unsigned get_num_args() const { return m_num_args; }
|
||||||
|
@ -215,13 +215,13 @@ public:
|
||||||
};
|
};
|
||||||
/** \brief Super class for lambda abstraction and pi (functional spaces). */
|
/** \brief Super class for lambda abstraction and pi (functional spaces). */
|
||||||
class expr_abstraction : public expr_cell {
|
class expr_abstraction : public expr_cell {
|
||||||
unsigned m_total_size;
|
unsigned m_depth;
|
||||||
name m_name;
|
name m_name;
|
||||||
expr m_domain;
|
expr m_domain;
|
||||||
expr m_body;
|
expr m_body;
|
||||||
friend class expr_cell;
|
friend class expr_cell;
|
||||||
void dealloc(buffer<expr_cell*> & todelete);
|
void dealloc(buffer<expr_cell*> & todelete);
|
||||||
friend unsigned get_size(expr const & e);
|
friend unsigned get_depth(expr const & e);
|
||||||
public:
|
public:
|
||||||
expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & e);
|
expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & e);
|
||||||
name const & get_name() const { return m_name; }
|
name const & get_name() const { return m_name; }
|
||||||
|
@ -240,14 +240,14 @@ public:
|
||||||
};
|
};
|
||||||
/** \brief Let expressions */
|
/** \brief Let expressions */
|
||||||
class expr_let : public expr_cell {
|
class expr_let : public expr_cell {
|
||||||
unsigned m_total_size;
|
unsigned m_depth;
|
||||||
name m_name;
|
name m_name;
|
||||||
optional<expr> m_type;
|
optional<expr> m_type;
|
||||||
expr m_value;
|
expr m_value;
|
||||||
expr m_body;
|
expr m_body;
|
||||||
friend class expr_cell;
|
friend class expr_cell;
|
||||||
void dealloc(buffer<expr_cell*> & todelete);
|
void dealloc(buffer<expr_cell*> & todelete);
|
||||||
friend unsigned get_size(expr const & e);
|
friend unsigned get_depth(expr const & e);
|
||||||
public:
|
public:
|
||||||
expr_let(name const & n, optional<expr> const & t, expr const & v, expr const & b);
|
expr_let(name const & n, optional<expr> const & t, expr const & v, expr const & b);
|
||||||
~expr_let();
|
~expr_let();
|
||||||
|
@ -523,8 +523,8 @@ inline expr const & let_value(expr const & e) { return to_let(e)->ge
|
||||||
inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); }
|
inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); }
|
||||||
inline name const & metavar_name(expr const & e) { return to_metavar(e)->get_name(); }
|
inline name const & metavar_name(expr const & e) { return to_metavar(e)->get_name(); }
|
||||||
inline local_context const & metavar_lctx(expr const & e) { return to_metavar(e)->get_lctx(); }
|
inline local_context const & metavar_lctx(expr const & e) { return to_metavar(e)->get_lctx(); }
|
||||||
/** \brief Return the size of the given expression */
|
/** \brief Return the depth of the given expression */
|
||||||
unsigned get_size(expr const & e);
|
unsigned get_depth(expr const & e);
|
||||||
|
|
||||||
inline bool has_metavar(expr const & e) { return e.has_metavar(); }
|
inline bool has_metavar(expr const & e) { return e.has_metavar(); }
|
||||||
// =======================================
|
// =======================================
|
||||||
|
|
|
@ -17,10 +17,10 @@ static bool is_lt(optional<expr> const & a, optional<expr> const & b, bool use_h
|
||||||
|
|
||||||
bool is_lt(expr const & a, expr const & b, bool use_hash) {
|
bool is_lt(expr const & a, expr const & b, bool use_hash) {
|
||||||
if (is_eqp(a, b)) return false;
|
if (is_eqp(a, b)) return false;
|
||||||
unsigned a_sz = get_size(a);
|
unsigned da = get_depth(a);
|
||||||
unsigned b_sz = get_size(b);
|
unsigned db = get_depth(b);
|
||||||
if (a_sz < b_sz) return true;
|
if (da < db) return true;
|
||||||
if (a_sz > b_sz) return false;
|
if (da > db) return false;
|
||||||
if (a.kind() != b.kind()) return a.kind() < b.kind();
|
if (a.kind() != b.kind()) return a.kind() < b.kind();
|
||||||
if (use_hash) {
|
if (use_hash) {
|
||||||
if (a.hash() < b.hash()) return true;
|
if (a.hash() < b.hash()) return true;
|
||||||
|
|
|
@ -631,8 +631,8 @@ static int expr_mk_eq(lua_State * L) {
|
||||||
return push_expr(L, mk_eq(to_expr(L, 1), to_expr(L, 2), to_expr(L, 3)));
|
return push_expr(L, mk_eq(to_expr(L, 1), to_expr(L, 2), to_expr(L, 3)));
|
||||||
}
|
}
|
||||||
|
|
||||||
static int expr_size(lua_State * L) {
|
static int expr_depth(lua_State * L) {
|
||||||
lua_pushinteger(L, get_size(to_expr(L, 1)));
|
lua_pushinteger(L, get_depth(to_expr(L, 1)));
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -661,7 +661,7 @@ static const struct luaL_Reg expr_m[] = {
|
||||||
{"data", safe_function<expr_fields>},
|
{"data", safe_function<expr_fields>},
|
||||||
{"args", safe_function<expr_args>},
|
{"args", safe_function<expr_args>},
|
||||||
{"num_args", safe_function<expr_num_args>},
|
{"num_args", safe_function<expr_num_args>},
|
||||||
{"size", safe_function<expr_size>},
|
{"depth", safe_function<expr_depth>},
|
||||||
{"arg", safe_function<expr_arg>},
|
{"arg", safe_function<expr_arg>},
|
||||||
{"abst_name", safe_function<expr_abst_name>},
|
{"abst_name", safe_function<expr_abst_name>},
|
||||||
{"abst_domain", safe_function<expr_abst_domain>},
|
{"abst_domain", safe_function<expr_abst_domain>},
|
||||||
|
|
|
@ -1,9 +0,0 @@
|
||||||
local f = Const("f")
|
|
||||||
assert(f:size() == 1)
|
|
||||||
local x = Var(0)
|
|
||||||
assert(x:size() == 1)
|
|
||||||
assert(f(x):size() == 3)
|
|
||||||
local t = f(x)
|
|
||||||
assert(t(x):size() == 4)
|
|
||||||
assert(t(x)(f(x)):size() == 7)
|
|
||||||
assert(mk_lambda("x", Const("Nat"), f(x)):size() == 5)
|
|
Loading…
Reference in a new issue