diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 9808d55aa..e30f6f550 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -37,9 +37,11 @@ bool has_meta(levels const & ls) { return false; } -expr_cell::expr_cell(expr_kind k, unsigned h, bool has_mv): +expr_cell::expr_cell(expr_kind k, unsigned h, bool has_mv, bool has_local): m_kind(static_cast(k)), - m_flags(has_mv ? 4 : 0), + m_flags(0), + m_has_mv(has_mv), + m_has_local(has_local), m_hash(h), m_rc(0) { // m_hash_alloc does not need to be a unique identifier. @@ -63,8 +65,8 @@ void expr_cell::dec_ref(expr & e, buffer & todelete) { } optional expr_cell::is_arrow() const { - // it is stored in bits 3-4 - unsigned r = (m_flags & (8+16)) >> 3; + // it is stored in bits 2-3 + unsigned r = (m_flags & (4+8)) >> 2; if (r == 0) { return optional(); } else if (r == 1) { @@ -76,25 +78,25 @@ optional expr_cell::is_arrow() const { } void expr_cell::set_is_arrow(bool flag) { - unsigned mask = flag ? 8 : 16; + unsigned mask = flag ? 4 : 8; m_flags |= mask; lean_assert(is_arrow() && *is_arrow() == flag); } // Expr variables expr_var::expr_var(unsigned idx): - expr_cell(expr_kind::Var, idx, false), + expr_cell(expr_kind::Var, idx, false, false), m_vidx(idx) {} // Expr constants expr_const::expr_const(name const & n, levels const & ls): - expr_cell(expr_kind::Constant, ::lean::hash(n.hash(), hash_levels(ls)), has_meta(ls)), + expr_cell(expr_kind::Constant, ::lean::hash(n.hash(), hash_levels(ls)), has_meta(ls), false), m_name(n), m_levels(ls) {} // Expr metavariables and local variables expr_mlocal::expr_mlocal(bool is_meta, name const & n, expr const & t): - expr_cell(is_meta ? expr_kind::Meta : expr_kind::Local, n.hash(), is_meta || t.has_metavar()), + expr_cell(is_meta ? expr_kind::Meta : expr_kind::Local, n.hash(), is_meta || t.has_metavar(), !is_meta || t.has_local()), m_name(n), m_type(t) {} void expr_mlocal::dealloc(buffer & todelete) { @@ -103,13 +105,15 @@ void expr_mlocal::dealloc(buffer & todelete) { } // Composite expressions -expr_composite::expr_composite(expr_kind k, unsigned h, bool has_mv, unsigned d): - expr_cell(k, h, has_mv), +expr_composite::expr_composite(expr_kind k, unsigned h, bool has_mv, bool has_local, unsigned d): + expr_cell(k, h, has_mv, has_local), m_depth(d) {} // Expr dependent pairs expr_dep_pair::expr_dep_pair(expr const & f, expr const & s, expr const & t): - expr_composite(expr_kind::Pair, ::lean::hash(f.hash(), s.hash()), f.has_metavar() || s.has_metavar() || t.has_metavar(), + expr_composite(expr_kind::Pair, ::lean::hash(f.hash(), s.hash()), + f.has_metavar() || s.has_metavar() || t.has_metavar(), + f.has_local() || s.has_local() || t.has_local(), std::max(get_depth(f), get_depth(s))+1), m_first(f), m_second(s), m_type(t) { } @@ -122,7 +126,7 @@ void expr_dep_pair::dealloc(buffer & todelete) { // Expr pair projection expr_proj::expr_proj(bool f, expr const & e): - expr_composite(f ? expr_kind::Fst : expr_kind::Snd, ::lean::hash(17, e.hash()), e.has_metavar(), get_depth(e)+1), + expr_composite(f ? expr_kind::Fst : expr_kind::Snd, ::lean::hash(17, e.hash()), e.has_metavar(), e.has_local(), get_depth(e)+1), m_expr(e) {} void expr_proj::dealloc(buffer & todelete) { dec_ref(m_expr, todelete); @@ -131,7 +135,9 @@ void expr_proj::dealloc(buffer & todelete) { // Expr applications expr_app::expr_app(expr const & fn, expr const & arg): - expr_composite(expr_kind::App, ::lean::hash(fn.hash(), arg.hash()), fn.has_metavar() || arg.has_metavar(), + expr_composite(expr_kind::App, ::lean::hash(fn.hash(), arg.hash()), + fn.has_metavar() || arg.has_metavar(), + fn.has_local() || arg.has_local(), std::max(get_depth(fn), get_depth(arg)) + 1), m_fn(fn), m_arg(arg) {} void expr_app::dealloc(buffer & todelete) { @@ -142,7 +148,9 @@ void expr_app::dealloc(buffer & todelete) { // Expr binders (Lambda, Pi and Sigma) expr_binder::expr_binder(expr_kind k, name const & n, expr const & t, expr const & b): - expr_composite(k, ::lean::hash(t.hash(), b.hash()), t.has_metavar() || b.has_metavar(), + expr_composite(k, ::lean::hash(t.hash(), b.hash()), + t.has_metavar() || b.has_metavar(), + t.has_local() || b.has_local(), std::max(get_depth(t), get_depth(b)) + 1), m_name(n), m_domain(t), @@ -157,14 +165,16 @@ void expr_binder::dealloc(buffer & todelete) { // Expr Sort expr_sort::expr_sort(level const & l): - expr_cell(expr_kind::Sort, ::lean::hash(l), has_meta(l)), + expr_cell(expr_kind::Sort, ::lean::hash(l), has_meta(l), false), m_level(l) { } expr_sort::~expr_sort() {} // Expr Let expr_let::expr_let(name const & n, expr const & t, expr const & v, expr const & b): - expr_composite(expr_kind::Let, ::lean::hash(v.hash(), b.hash()), t.has_metavar() || v.has_metavar() || b.has_metavar(), + expr_composite(expr_kind::Let, ::lean::hash(v.hash(), b.hash()), + t.has_metavar() || v.has_metavar() || b.has_metavar(), + t.has_local() || v.has_local() || b.has_local(), std::max({get_depth(t), get_depth(v), get_depth(b)}) + 1), m_name(n), m_type(t), @@ -215,7 +225,7 @@ static expr read_macro(deserializer & d) { } expr_macro::expr_macro(macro * m): - expr_cell(expr_kind::Macro, m->hash(), false), + expr_cell(expr_kind::Macro, m->hash(), false, false), m_macro(m) { m_macro->inc_ref(); } diff --git a/src/kernel/expr.h b/src/kernel/expr.h index c722b2f11..af87ddc4d 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -51,11 +51,13 @@ protected: // The bits of the following field mean: // 0 - term is maximally shared // 1 - term is closed - // 2 - term contains metavariables - // 3-4 - term is an arrow (0 - not initialized, 1 - is arrow, 2 - is not arrow) - atomic_ushort m_flags; - unsigned m_hash; // hash based on the structure of the expression (this is a good hash for structural equality) - unsigned m_hash_alloc; // hash based on 'time' of allocation (this is a good hash for pointer-based equality) + // 2-3 - term is an arrow (0 - not initialized, 1 - is arrow, 2 - is not arrow) + // Remark: we use atomic_uchar because these flags are computed lazily (i.e., after the expression is created) + atomic_uchar m_flags; + unsigned m_has_mv:1; // term contains metavariables + unsigned m_has_local:1; // term contains local constants + unsigned m_hash; // hash based on the structure of the expression (this is a good hash for structural equality) + unsigned m_hash_alloc; // hash based on 'time' of allocation (this is a good hash for pointer-based equality) MK_LEAN_RC(); // Declare m_rc counter void dealloc(); @@ -73,11 +75,12 @@ protected: friend class has_free_var_fn; static void dec_ref(expr & c, buffer & todelete); public: - expr_cell(expr_kind k, unsigned h, bool has_mv); + expr_cell(expr_kind k, unsigned h, bool has_mv, bool has_local); expr_kind kind() const { return static_cast(m_kind); } unsigned hash() const { return m_hash; } unsigned hash_alloc() const { return m_hash_alloc; } - bool has_metavar() const { return (m_flags & 4) != 0; } + bool has_metavar() const { return m_has_mv; } + bool has_local() const { return m_has_local; } }; class macro; @@ -115,6 +118,7 @@ public: unsigned hash() const { return m_ptr ? m_ptr->hash() : 23; } unsigned hash_alloc() const { return m_ptr ? m_ptr->hash_alloc() : 23; } bool has_metavar() const { return m_ptr->has_metavar(); } + bool has_local() const { return m_ptr->has_local(); } expr_cell * raw() const { return m_ptr; } @@ -194,7 +198,7 @@ class expr_composite : public expr_cell { unsigned m_depth; friend unsigned get_depth(expr const & e); public: - expr_composite(expr_kind k, unsigned h, bool has_mv, unsigned d); + expr_composite(expr_kind k, unsigned h, bool has_mv, bool has_local, unsigned d); }; /** \brief Applications */ @@ -510,6 +514,7 @@ inline expr const & mlocal_type(expr const & e) { return to_mlocal(e) inline bool is_constant(expr const & e, name const & n) { return is_constant(e) && const_name(e) == n; } inline bool has_metavar(expr const & e) { return e.has_metavar(); } +inline bool has_local(expr const & e) { return e.has_local(); } unsigned get_depth(expr const & e); // ======================================= diff --git a/src/tests/kernel/expr.cpp b/src/tests/kernel/expr.cpp index edeba03b0..1a6bc6c1b 100644 --- a/src/tests/kernel/expr.cpp +++ b/src/tests/kernel/expr.cpp @@ -303,6 +303,15 @@ static void tst15() { lean_assert(has_metavar(f(a, a, m))); lean_assert(has_metavar(f(a, m, a, a))); lean_assert(!has_metavar(f(a, a, a, a))); + lean_assert(!has_metavar(mk_fst(a))); + lean_assert(!has_metavar(mk_snd(a))); + lean_assert(has_metavar(mk_fst(m))); + lean_assert(has_metavar(mk_snd(m))); + lean_assert(!has_metavar(mk_pair(a, x, x))); + lean_assert(has_metavar(mk_pair(f(m), x, x))); + lean_assert(has_metavar(mk_pair(f(a), m, x))); + lean_assert(has_metavar(mk_pair(f(a), a, m))); + lean_assert(has_metavar(mk_pair(f(a), a, f(m)))); } static void check_copy(expr const & e) { @@ -338,6 +347,52 @@ static void tst17() { check_serializer(t); } +static void tst18() { + expr f = Const("f"); + expr x = Var(0); + expr a = Const("a"); + expr l = mk_local("m", Bool); + expr m = mk_metavar("m", Bool); + check_serializer(l); + lean_assert(!has_local(m)); + lean_assert(has_local(l)); + lean_assert(!has_local(f(m))); + lean_assert(has_local(f(l))); + lean_assert(!has_local(f(a))); + lean_assert(!has_local(f(x))); + lean_assert(!has_local(Pi({a, Type}, a))); + lean_assert(!has_local(Pi({a, m}, a))); + lean_assert(!has_local(Type)); + lean_assert(!has_local(Pi({a, Type}, a))); + lean_assert(has_local(Pi({a, Type}, l))); + lean_assert(!has_metavar(Pi({a, Type}, l))); + lean_assert(has_local(Pi({a, l}, a))); + lean_assert(has_local(Fun({a, Type}, l))); + lean_assert(has_local(Fun({a, l}, a))); + lean_assert(!has_local(Let(a, Type, Bool, a))); + lean_assert(has_local(mk_let(name("a"), l, f(x), f(f(x))))); + lean_assert(has_local(mk_let(name("a"), Type, f(l), f(f(x))))); + lean_assert(has_local(mk_let(name("a"), Type, f(x), f(f(l))))); + lean_assert(has_local(f(a, a, l))); + lean_assert(has_local(f(a, l, a, a))); + lean_assert(!has_local(f(a, a, a, a))); + lean_assert(!has_local(mk_fst(a))); + lean_assert(!has_local(mk_snd(a))); + lean_assert(has_local(mk_fst(l))); + lean_assert(has_local(mk_snd(l))); + lean_assert(!has_local(mk_fst(m))); + lean_assert(!has_local(mk_snd(m))); + lean_assert(!has_local(mk_pair(a, x, x))); + lean_assert(has_local(mk_pair(f(l), x, x))); + lean_assert(has_local(mk_pair(f(a), l, x))); + lean_assert(has_local(mk_pair(f(a), a, l))); + lean_assert(has_local(mk_pair(f(a), a, f(l)))); + lean_assert(!has_local(mk_pair(f(m), x, x))); + lean_assert(!has_local(mk_pair(f(a), m, x))); + lean_assert(!has_local(mk_pair(f(a), a, m))); + lean_assert(!has_local(mk_pair(f(a), a, f(m)))); +} + int main() { save_stack_info(); lean_assert(sizeof(expr) == sizeof(optional)); @@ -358,6 +413,7 @@ int main() { tst15(); tst16(); tst17(); + tst18(); std::cout << "sizeof(expr): " << sizeof(expr) << "\n"; std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n"; std::cout << "sizeof(expr_cell): " << sizeof(expr_cell) << "\n";