diff --git a/src/kernel/converter.cpp b/src/kernel/converter.cpp index 8356cfa6d..79d8929d3 100644 --- a/src/kernel/converter.cpp +++ b/src/kernel/converter.cpp @@ -304,7 +304,7 @@ struct default_converter : public converter { expr var_s_type = instantiate(binding_domain(s), subst.size(), subst.data()); if (!is_def_eq(var_t_type, var_s_type, c, jst)) return false; - subst.push_back(mk_local(c.mk_fresh_name() + binding_name(s), var_s_type)); + subst.push_back(mk_local(c.mk_fresh_name(), binding_name(s), var_s_type)); t = binding_body(t); s = binding_body(s); } while (t.kind() == k && s.kind() == k); diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index fb2957212..d796d3532 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -114,6 +114,10 @@ void expr_mlocal::dealloc(buffer & todelete) { delete(this); } +expr_local::expr_local(name const & n, name const & pp_name, expr const & t): + expr_mlocal(false, n, t), + m_pp_name(pp_name) {} + // Composite expressions expr_composite::expr_composite(expr_kind k, unsigned h, bool has_mv, bool has_local, bool has_param_univ, unsigned d, unsigned fv_range): expr_cell(k, h, has_mv, has_local, has_param_univ), @@ -263,8 +267,8 @@ void expr_cell::dealloc() { switch (it->kind()) { case expr_kind::Var: delete static_cast(it); break; case expr_kind::Macro: static_cast(it)->dealloc(todo); break; - case expr_kind::Meta: - case expr_kind::Local: static_cast(it)->dealloc(todo); break; + case expr_kind::Meta: static_cast(it)->dealloc(todo); break; + case expr_kind::Local: static_cast(it)->dealloc(todo); break; case expr_kind::Constant: delete static_cast(it); break; case expr_kind::Sort: delete static_cast(it); break; case expr_kind::App: static_cast(it)->dealloc(todo); break; @@ -414,10 +418,12 @@ expr update_let(expr const & e, expr const & new_type, expr const & new_val, exp } expr update_mlocal(expr const & e, expr const & new_type) { - if (!is_eqp(mlocal_type(e), new_type)) - return copy_tag(e, mk_mlocal(is_metavar(e), mlocal_name(e), new_type)); - else + if (is_eqp(mlocal_type(e), new_type)) return e; + else if (is_metavar(e)) + return copy_tag(e, mk_metavar(mlocal_name(e), new_type)); + else + return copy_tag(e, mk_local(mlocal_name(e), local_pp_name(e), new_type)); } expr update_sort(expr const & e, level const & new_level) { @@ -484,7 +490,7 @@ expr copy(expr const & a) { case expr_kind::Pi: return mk_pi(binding_name(a), binding_domain(a), binding_body(a), binding_info(a)); case expr_kind::Let: return mk_let(let_name(a), let_type(a), let_value(a), let_body(a)); case expr_kind::Meta: return mk_metavar(mlocal_name(a), mlocal_type(a)); - case expr_kind::Local: return mk_local(mlocal_name(a), mlocal_type(a)); + case expr_kind::Local: return mk_local(mlocal_name(a), local_pp_name(a), mlocal_type(a)); } lean_unreachable(); // LCOV_EXCL_LINE } diff --git a/src/kernel/expr.h b/src/kernel/expr.h index e27b63d55..831b3e72f 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -126,7 +126,8 @@ public: friend expr mk_var(unsigned idx); friend expr mk_sort(level const & l); friend expr mk_constant(name const & n, levels const & ls); - friend expr mk_mlocal(bool is_meta, name const & n, expr const & t); + friend expr mk_metavar(name const & n, expr const & t); + friend expr mk_local(name const & n, name const & pp_n, expr const & t); friend expr mk_app(expr const & f, expr const & a); friend expr mk_pair(expr const & f, expr const & s, expr const & t); friend expr mk_proj(bool fst, expr const & p); @@ -197,6 +198,17 @@ public: expr const & get_type() const { return m_type; } }; +/** \brief expr_mlocal subclass for local constants. */ +class expr_local : public expr_mlocal { + // The name used in the binder that generate this local, + // it is only used for pretty printing. This field is ignored + // when comparing expressions. + name m_pp_name; +public: + expr_local(name const & n, name const & pp_name, expr const & t); + name const & get_pp_name() const { return m_pp_name; } +}; + /** \brief Composite expressions */ class expr_composite : public expr_cell { unsigned m_depth; @@ -415,9 +427,8 @@ inline expr mk_constant(name const & n, levels const & ls) { return expr(new exp inline expr mk_constant(name const & n) { return mk_constant(n, levels()); } inline expr Const(name const & n) { return mk_constant(n); } inline expr mk_macro(macro_definition const & m, unsigned num = 0, expr const * args = nullptr) { return expr(new expr_macro(m, num, args)); } -inline expr mk_mlocal(bool is_meta, name const & n, expr const & t) { return expr(new expr_mlocal(is_meta, n, t)); } -inline expr mk_metavar(name const & n, expr const & t) { return mk_mlocal(true, n, t); } -inline expr mk_local(name const & n, expr const & t) { return mk_mlocal(false, n, t); } +inline expr mk_metavar(name const & n, expr const & t) { return expr(new expr_mlocal(true, n, t)); } +inline expr mk_local(name const & n, name const & pp_n, expr const & t) { return expr(new expr_local(n, pp_n, t)); } inline expr mk_app(expr const & f, expr const & a) { return expr(new expr_app(f, a)); } expr mk_app(expr const & f, unsigned num_args, expr const * args); expr mk_app(unsigned num_args, expr const * args); @@ -493,7 +504,7 @@ inline expr_binding * to_binding(expr_cell * e) { lean_assert(is_binder(e inline expr_let * to_let(expr_cell * e) { lean_assert(is_let(e)); return static_cast(e); } inline expr_sort * to_sort(expr_cell * e) { lean_assert(is_sort(e)); return static_cast(e); } inline expr_mlocal * to_mlocal(expr_cell * e) { lean_assert(is_mlocal(e)); return static_cast(e); } -inline expr_mlocal * to_local(expr_cell * e) { lean_assert(is_local(e)); return static_cast(e); } +inline expr_local * to_local(expr_cell * e) { lean_assert(is_local(e)); return static_cast(e); } inline expr_mlocal * to_metavar(expr_cell * e) { lean_assert(is_metavar(e)); return static_cast(e); } inline expr_macro * to_macro(expr_cell * e) { lean_assert(is_macro(e)); return static_cast(e); } @@ -505,7 +516,7 @@ inline expr_let * to_let(expr const & e) { return to_let(e.raw() inline expr_sort * to_sort(expr const & e) { return to_sort(e.raw()); } inline expr_mlocal * to_mlocal(expr const & e) { return to_mlocal(e.raw()); } inline expr_mlocal * to_metavar(expr const & e) { return to_metavar(e.raw()); } -inline expr_mlocal * to_local(expr const & e) { return to_local(e.raw()); } +inline expr_local * to_local(expr const & e) { return to_local(e.raw()); } inline expr_macro * to_macro(expr const & e) { return to_macro(e.raw()); } // ======================================= @@ -561,6 +572,7 @@ inline expr const & let_type(expr const & e) { return to_let(e)-> inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } inline name const & mlocal_name(expr const & e) { return to_mlocal(e)->get_name(); } inline expr const & mlocal_type(expr const & e) { return to_mlocal(e)->get_type(); } +inline name const & local_pp_name(expr const & e) { return to_local(e)->get_pp_name(); } inline bool is_constant(expr const & e, name const & n) { return is_constant(e) && const_name(e) == n; } inline bool has_metavar(expr const & e) { return e.has_metavar(); } diff --git a/src/kernel/formatter.cpp b/src/kernel/formatter.cpp index 56442011e..8b33416c3 100644 --- a/src/kernel/formatter.cpp +++ b/src/kernel/formatter.cpp @@ -99,7 +99,7 @@ struct print_expr_fn { out() << "?" << mlocal_name(a); break; case expr_kind::Local: - out() << "!" << mlocal_name(a); + out() << local_pp_name(a); break; case expr_kind::Var: { auto e = find(c, var_idx(a)); diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 0b906f38b..51d9397dd 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -85,7 +85,7 @@ struct type_checker::imp { It also returns the fresh local constant. */ std::pair open_binding_body(expr const & e) { - expr local = mk_local(m_gen.next() + binding_name(e), binding_domain(e)); + expr local = mk_local(m_gen.next(), binding_name(e), binding_domain(e)); return mk_pair(instantiate(binding_body(e), local), local); } diff --git a/src/library/deep_copy.cpp b/src/library/deep_copy.cpp index 408a62890..fbca3d52d 100644 --- a/src/library/deep_copy.cpp +++ b/src/library/deep_copy.cpp @@ -32,7 +32,7 @@ class deep_copy_fn { case expr_kind::Pi: r = mk_pi(binding_name(a), apply(binding_domain(a)), apply(binding_body(a))); break; case expr_kind::Let: r = mk_let(let_name(a), apply(let_type(a)), apply(let_value(a)), apply(let_body(a))); break; case expr_kind::Meta: r = mk_metavar(mlocal_name(a), apply(mlocal_type(a))); break; - case expr_kind::Local: r = mk_local(mlocal_name(a), apply(mlocal_type(a))); break; + case expr_kind::Local: r = mk_local(mlocal_name(a), local_pp_name(a), apply(mlocal_type(a))); break; } if (sh) m_cache.insert(std::make_pair(a.raw(), r)); diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index c5cf8281c..28f6328a2 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -352,7 +352,14 @@ static int expr_fun(lua_State * L) { return expr_abst(L); } static int expr_pi(lua_State * L) { return expr_abst(L); } static int expr_mk_sort(lua_State * L) { return push_expr(L, mk_sort(to_level(L, 1))); } static int expr_mk_metavar(lua_State * L) { return push_expr(L, mk_metavar(to_name_ext(L, 1), to_expr(L, 2))); } -static int expr_mk_local(lua_State * L) { return push_expr(L, mk_local(to_name_ext(L, 1), to_expr(L, 2))); } +static int expr_mk_local(lua_State * L) { + int nargs = lua_gettop(L); + name n = to_name_ext(L, 1); + if (nargs == 2) + return push_expr(L, mk_local(n, n, to_expr(L, 2))); + else + return push_expr(L, mk_local(n, to_name_ext(L, 2), to_expr(L, 3))); +} static int expr_get_kind(lua_State * L) { return push_integer(L, static_cast(to_expr(L, 1).kind())); } // t is a table of pairs {{a1, b1, c1}, ..., {ak, bk, ck}} diff --git a/src/library/kernel_serializer.cpp b/src/library/kernel_serializer.cpp index 11a088a16..53d6b275a 100644 --- a/src/library/kernel_serializer.cpp +++ b/src/library/kernel_serializer.cpp @@ -162,9 +162,12 @@ class expr_serializer : public object_serializer> A), Fun({{A, mk_Type()}, {x, A}}, x))); - expr c = mk_local("c", Bool); + expr c = mk_local("c", "c", Bool); expr id = Const("id"); type_checker checker(env3, name_generator("tmp")); lean_assert(checker.check(id(Bool)) == Bool >> Bool); @@ -90,8 +90,8 @@ static void tst2() { expr f97 = Const(name(base, 97)); expr f98 = Const(name(base, 98)); expr f3 = Const(name(base, 3)); - expr c1 = mk_local("c1", Bool); - expr c2 = mk_local("c2", Bool); + expr c1 = mk_local("c1", "c1", Bool); + expr c2 = mk_local("c2", "c2", Bool); expr id = Const("id"); std::cout << checker.whnf(f3(c1, c2)) << "\n"; lean_assert_eq(env.find(name(base, 98))->get_weight(), 98); diff --git a/src/tests/kernel/expr.cpp b/src/tests/kernel/expr.cpp index a2a8547db..eaec07023 100644 --- a/src/tests/kernel/expr.cpp +++ b/src/tests/kernel/expr.cpp @@ -340,7 +340,7 @@ static void tst18() { expr f = Const("f"); expr x = Var(0); expr a = Const("a"); - expr l = mk_local("m", Bool); + expr l = mk_local("m", "m", Bool); expr m = mk_metavar("m", Bool); check_serializer(l); lean_assert(!has_local(m));