diff --git a/src/builtin/kernel.lean b/src/builtin/kernel.lean index a0ef26c56..057dbb82d 100644 --- a/src/builtin/kernel.lean +++ b/src/builtin/kernel.lean @@ -12,8 +12,8 @@ definition TypeU := (Type U) variable Bool : Type -- Heterogeneous equality -variable heq {A B : (Type U)} (a : A) (b : B) : Bool -infixl 50 == : heq +variable heq2 {A B : (Type U)} (a : A) (b : B) : Bool +infixl 50 == : heq2 -- Reflexivity for heterogeneous equality axiom hrefl {A : (Type U)} (a : A) : a == a diff --git a/src/builtin/obj/Int.olean b/src/builtin/obj/Int.olean index 83995e037..e1a16bf66 100644 Binary files a/src/builtin/obj/Int.olean and b/src/builtin/obj/Int.olean differ diff --git a/src/builtin/obj/Nat.olean b/src/builtin/obj/Nat.olean index e351cb21e..f9391e528 100644 Binary files a/src/builtin/obj/Nat.olean and b/src/builtin/obj/Nat.olean differ diff --git a/src/builtin/obj/kernel.olean b/src/builtin/obj/kernel.olean index f3ac6fb14..ce394517b 100644 Binary files a/src/builtin/obj/kernel.olean and b/src/builtin/obj/kernel.olean differ diff --git a/src/builtin/obj/optional.olean b/src/builtin/obj/optional.olean index befd296dc..3b60f9f15 100644 Binary files a/src/builtin/obj/optional.olean and b/src/builtin/obj/optional.olean differ diff --git a/src/builtin/obj/subtype.olean b/src/builtin/obj/subtype.olean index a408f48c2..3a235d759 100644 Binary files a/src/builtin/obj/subtype.olean and b/src/builtin/obj/subtype.olean differ diff --git a/src/builtin/obj/sum.olean b/src/builtin/obj/sum.olean index 737fd28a9..cd3e5b3ef 100644 Binary files a/src/builtin/obj/sum.olean and b/src/builtin/obj/sum.olean differ diff --git a/src/frontends/lean/notation.h b/src/frontends/lean/notation.h index e87d12105..757dd568e 100644 --- a/src/frontends/lean/notation.h +++ b/src/frontends/lean/notation.h @@ -7,6 +7,7 @@ Author: Leonardo de Moura #pragma once #include namespace lean { +constexpr unsigned g_heq_precedence = 50; constexpr unsigned g_arrow_precedence = 25; constexpr unsigned g_cartesian_product_precedence = 30; constexpr unsigned g_app_precedence = std::numeric_limits::max(); diff --git a/src/frontends/lean/pp.cpp b/src/frontends/lean/pp.cpp index 1ac669993..3e00e5d40 100644 --- a/src/frontends/lean/pp.cpp +++ b/src/frontends/lean/pp.cpp @@ -84,6 +84,7 @@ static format g_geq_fmt = format("\u2265"); static format g_lift_fmt = highlight_keyword(format("lift")); static format g_inst_fmt = highlight_keyword(format("inst")); static format g_sig_fmt = highlight_keyword(format("sig")); +static format g_heq_fmt = highlight_keyword(format("==")); static format g_cartesian_product_fmt = highlight_keyword(format("#")); static format g_cartesian_product_n_fmt = highlight_keyword(format("\u2A2F")); @@ -271,6 +272,7 @@ class pp_fn { return false; case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Let: case expr_kind::Sigma: case expr_kind::Pair: case expr_kind::Proj: + case expr_kind::HEq: return false; } return false; @@ -482,6 +484,8 @@ class pp_fn { return g_cartesian_product_precedence; } else if (is_lambda(e) || is_pi(e) || is_let(e) || is_exists(e) || is_sigma(e) || is_pair(e)) { return 0; + } else if (is_heq(e)) { + return g_heq_precedence; } else { return g_app_precedence; } @@ -1163,6 +1167,13 @@ class pp_fn { return result(group(r_format), r_weight); } + result pp_heq(expr const & a, unsigned depth) { + result p_lhs = pp_child(heq_lhs(a), depth); + result p_rhs = pp_child(heq_rhs(a), depth); + format r_format = group(format{p_lhs.first, space(), g_heq_fmt, line(), p_rhs.first}); + return mk_result(r_format, p_lhs.second + p_rhs.second + 1); + } + result pp(expr const & e, unsigned depth, bool main = false) { check_system("pretty printer"); if (!is_atomic(e) && (m_num_steps > m_max_steps || depth > m_max_depth)) { @@ -1199,6 +1210,7 @@ class pp_fn { case expr_kind::Type: r = pp_type(e); break; case expr_kind::Let: r = pp_let(e, depth); break; case expr_kind::MetaVar: r = pp_metavar(e, depth); break; + case expr_kind::HEq: r = pp_heq(e, depth); break; case expr_kind::Pair: r = pp_tuple(e, depth); break; case expr_kind::Proj: r = pp_proj(e, depth); break; } diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 412c4c4f6..9f4e4b038 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -22,6 +22,7 @@ namespace lean { static expr g_dummy(mk_var(0)); expr::expr():expr(g_dummy) {} +// Local context entries local_entry::local_entry(unsigned s, unsigned n):m_kind(local_entry_kind::Lift), m_s(s), m_n(n) {} local_entry::local_entry(unsigned s, expr const & v):m_kind(local_entry_kind::Inst), m_s(s), m_v(v) {} local_entry::~local_entry() {} @@ -87,10 +88,12 @@ void expr_cell::set_is_arrow(bool flag) { lean_assert(is_arrow() && *is_arrow() == flag); } +// Expr variables expr_var::expr_var(unsigned idx): expr_cell(expr_kind::Var, idx, false), m_vidx(idx) {} +// Expr constants expr_const::expr_const(name const & n, optional const & t): expr_cell(expr_kind::Constant, n.hash(), t && t->has_metavar()), m_name(n), @@ -99,6 +102,19 @@ void expr_const::dealloc(buffer & todelete) { dec_ref(m_type, todelete); delete(this); } + +// Expr heterogeneous equality +expr_heq::expr_heq(expr const & lhs, expr const & rhs): + expr_cell(expr_kind::HEq, ::lean::hash(lhs.hash(), rhs.hash()), lhs.has_metavar() || rhs.has_metavar()), + m_lhs(lhs), m_rhs(rhs), m_depth(std::max(get_depth(lhs), get_depth(rhs))+1) { +} +void expr_heq::dealloc(buffer & todelete) { + dec_ref(m_lhs, todelete); + dec_ref(m_rhs, todelete); + delete(this); +} + +// Expr dependent pairs expr_dep_pair::expr_dep_pair(expr const & f, expr const & s, expr const & t): expr_cell(expr_kind::Pair, ::lean::hash(f.hash(), s.hash()), f.has_metavar() || s.has_metavar() || t.has_metavar()), m_first(f), m_second(s), m_type(t), m_depth(std::max(get_depth(f), get_depth(s))+1) { @@ -109,6 +125,8 @@ void expr_dep_pair::dealloc(buffer & todelete) { dec_ref(m_type, todelete); delete(this); } + +// Expr pair projection expr_proj::expr_proj(bool f, expr const & e): expr_cell(expr_kind::Proj, ::lean::hash(17, e.hash()), e.has_metavar()), m_first(f), m_depth(get_depth(e)+1), m_expr(e) {} @@ -116,6 +134,8 @@ void expr_proj::dealloc(buffer & todelete) { dec_ref(m_expr, todelete); delete(this); } + +// Expr applications expr_app::expr_app(unsigned num_args, bool has_mv): expr_cell(expr_kind::App, 0, has_mv), m_num_args(num_args) { @@ -163,6 +183,8 @@ expr mk_app(unsigned n, expr const * as) { to_app(r)->m_depth = depth + 1; return r; } + +// Expr abstractions (and subclasses: Lambda, Pi and Sigma) expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & b): expr_cell(k, ::lean::hash(t.hash(), b.hash()), t.has_metavar() || b.has_metavar()), m_name(n), @@ -175,14 +197,21 @@ void expr_abstraction::dealloc(buffer & todelete) { dec_ref(m_domain, todelete); delete(this); } + expr_lambda::expr_lambda(name const & n, expr const & t, expr const & e):expr_abstraction(expr_kind::Lambda, n, t, e) {} + expr_pi::expr_pi(name const & n, expr const & t, expr const & e):expr_abstraction(expr_kind::Pi, n, t, e) {} + expr_sigma::expr_sigma(name const & n, expr const & t, expr const & e):expr_abstraction(expr_kind::Sigma, n, t, e) {} + +// Expr Type expr_type::expr_type(level const & l): expr_cell(expr_kind::Type, l.hash(), false), m_level(l) { } expr_type::~expr_type() {} + +// Expr Let expr_let::expr_let(name const & n, optional const & t, expr const & v, expr const & b): expr_cell(expr_kind::Let, ::lean::hash(v.hash(), b.hash()), v.has_metavar() || b.has_metavar() || (t && t->has_metavar())), m_name(n), @@ -201,6 +230,8 @@ void expr_let::dealloc(buffer & todelete) { delete(this); } expr_let::~expr_let() {} + +// Expr Semantic attachment name value::get_unicode_name() const { return get_name(); } optional value::normalize(unsigned, expr const *) const { return none_expr(); } void value::display(std::ostream & out) const { out << get_name(); } @@ -244,10 +275,12 @@ static expr read_value(deserializer & d) { return it->second(d); } +// Expr Metavariable expr_metavar::expr_metavar(name const & n, local_context const & lctx): expr_cell(expr_kind::MetaVar, n.hash(), true), m_name(n), m_lctx(lctx) {} expr_metavar::~expr_metavar() {} + void expr_cell::dealloc() { try { buffer todo; @@ -268,6 +301,7 @@ void expr_cell::dealloc() { case expr_kind::Lambda: static_cast(it)->dealloc(todo); break; case expr_kind::Pi: static_cast(it)->dealloc(todo); break; case expr_kind::Sigma: static_cast(it)->dealloc(todo); break; + case expr_kind::HEq: static_cast(it)->dealloc(todo); break; case expr_kind::Let: static_cast(it)->dealloc(todo); break; } } @@ -306,6 +340,8 @@ unsigned get_depth(expr const & e) { case expr_kind::Var: case expr_kind::Constant: case expr_kind::Type: case expr_kind::Value: case expr_kind::MetaVar: return 1; + case expr_kind::HEq: + return to_heq(e)->m_depth; case expr_kind::Pair: return to_pair(e)->m_depth; case expr_kind::Proj: @@ -333,6 +369,7 @@ expr copy(expr const & a) { case expr_kind::Pi: return mk_pi(abst_name(a), abst_domain(a), abst_body(a)); case expr_kind::Sigma: return mk_sigma(abst_name(a), abst_domain(a), abst_body(a)); case expr_kind::Let: return mk_let(let_name(a), let_type(a), let_value(a), let_body(a)); + case expr_kind::HEq: return mk_heq(heq_lhs(a), heq_rhs(a)); case expr_kind::MetaVar: return mk_metavar(metavar_name(a), metavar_lctx(a)); } lean_unreachable(); // LCOV_EXCL_LINE @@ -376,7 +413,7 @@ constexpr bool is_small(expr_kind k) { return 0 <= static_cast(k) && stati static_assert(is_small(expr_kind::Var) && is_small(expr_kind::Constant) && is_small(expr_kind::Value) && is_small(expr_kind::App) && is_small(expr_kind::Pair) && is_small(expr_kind::Proj) && is_small(expr_kind::Lambda) && is_small(expr_kind::Pi) && is_small(expr_kind::Sigma) && is_small(expr_kind::Type) && - is_small(expr_kind::Let) && is_small(expr_kind::MetaVar), "expr_kind is too big"); + is_small(expr_kind::Let) && is_small(expr_kind::HEq) && is_small(expr_kind::MetaVar), "expr_kind is too big"); class expr_serializer : public object_serializer { typedef object_serializer super; @@ -415,6 +452,7 @@ class expr_serializer : public object_serializer const & t, expr const & v, expr const & e); + friend expr mk_heq(expr const & lhs, expr const & rhs); friend expr mk_metavar(name const & n, local_context const & ctx); friend bool is_eqp(expr const & a, expr const & b) { return a.m_ptr == b.m_ptr; } @@ -349,6 +351,22 @@ public: value const & get_value() const { return m_val; } }; + +/** \brief Heterogeneous equality */ +class expr_heq : public expr_cell { + expr m_lhs; + expr m_rhs; + unsigned m_depth; + friend expr_cell; + friend expr mk_heq(expr const & lhs, expr const & rhs); + void dealloc(buffer & todelete); + friend unsigned get_depth(expr const & e); +public: + expr_heq(expr const & lhs, expr const & rhs); + expr const & get_lhs() const { return m_lhs; } + expr const & get_rhs() const { return m_rhs; } +}; + /** \see local_entry */ @@ -432,6 +450,7 @@ inline bool is_pi(expr_cell * e) { return e->kind() == expr_kind::Pi; } inline bool is_sigma(expr_cell * e) { return e->kind() == expr_kind::Sigma; } inline bool is_type(expr_cell * e) { return e->kind() == expr_kind::Type; } inline bool is_let(expr_cell * e) { return e->kind() == expr_kind::Let; } +inline bool is_heq(expr_cell * e) { return e->kind() == expr_kind::HEq; } inline bool is_metavar(expr_cell * e) { return e->kind() == expr_kind::MetaVar; } inline bool is_abstraction(expr_cell * e) { return is_lambda(e) || is_pi(e) || is_sigma(e); } @@ -448,6 +467,7 @@ inline bool is_pi(expr const & e) { return e.kind() == expr_kind::Pi; } inline bool is_sigma(expr const & e) { return e.kind() == expr_kind::Sigma; } inline bool is_type(expr const & e) { return e.kind() == expr_kind::Type; } inline bool is_let(expr const & e) { return e.kind() == expr_kind::Let; } +inline bool is_heq(expr const & e) { return e.kind() == expr_kind::HEq; } inline bool is_metavar(expr const & e) { return e.kind() == expr_kind::MetaVar; } inline bool is_abstraction(expr const & e) { return is_lambda(e) || is_pi(e) || is_sigma(e); } // ======================================= @@ -487,6 +507,7 @@ inline expr mk_type(level const & l) { return expr(new expr_type(l)); } expr mk_type(); inline expr Type(level const & l) { return mk_type(l); } inline expr Type() { return mk_type(); } +inline expr mk_heq(expr const & lhs, expr const & rhs) { return expr(new expr_heq(lhs, rhs)); } inline expr mk_metavar(name const & n, local_context const & ctx = local_context()) { return expr(new expr_metavar(n, ctx)); } @@ -514,6 +535,7 @@ inline expr_pi * to_pi(expr_cell * e) { lean_assert(is_pi(e)); inline expr_sigma * to_sigma(expr_cell * e) { lean_assert(is_sigma(e)); return static_cast(e); } inline expr_type * to_type(expr_cell * e) { lean_assert(is_type(e)); return static_cast(e); } inline expr_let * to_let(expr_cell * e) { lean_assert(is_let(e)); return static_cast(e); } +inline expr_heq * to_heq(expr_cell * e) { lean_assert(is_heq(e)); return static_cast(e); } inline expr_metavar * to_metavar(expr_cell * e) { lean_assert(is_metavar(e)); return static_cast(e); } inline expr_var * to_var(expr const & e) { return to_var(e.raw()); } @@ -527,6 +549,7 @@ inline expr_pi * to_pi(expr const & e) { return to_pi(e.raw()) inline expr_sigma * to_sigma(expr const & e) { return to_sigma(e.raw()); } inline expr_let * to_let(expr const & e) { return to_let(e.raw()); } inline expr_type * to_type(expr const & e) { return to_type(e.raw()); } +inline expr_heq * to_heq(expr const & e) { return to_heq(e.raw()); } inline expr_metavar * to_metavar(expr const & e) { return to_metavar(e.raw()); } // ======================================= @@ -556,6 +579,8 @@ inline name const & let_name(expr_cell * e) { return to_let(e)->get inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); } inline optional const & let_type(expr_cell * e) { return to_let(e)->get_type(); } inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); } +inline expr const & heq_lhs(expr_cell * e) { return to_heq(e)->get_lhs(); } +inline expr const & heq_rhs(expr_cell * e) { return to_heq(e)->get_rhs(); } inline name const & metavar_name(expr_cell * e) { return to_metavar(e)->get_name(); } inline local_context const & metavar_lctx(expr_cell * e) { return to_metavar(e)->get_lctx(); } @@ -592,6 +617,8 @@ inline name const & let_name(expr const & e) { return to_let(e)->ge inline optional const & let_type(expr const & e) { return to_let(e)->get_type(); } inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); } inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } +inline expr const & heq_lhs(expr const & e) { return to_heq(e)->get_lhs(); } +inline expr const & heq_rhs(expr const & e) { return to_heq(e)->get_rhs(); } inline name const & metavar_name(expr const & e) { return to_metavar(e)->get_name(); } inline local_context const & metavar_lctx(expr const & e) { return to_metavar(e)->get_lctx(); } /** \brief Return the depth of the given expression */ @@ -765,6 +792,12 @@ inline expr update_proj(expr const & e, expr const & new_arg) { else return e; } +inline expr update_heq(expr const & e, expr const & new_lhs, expr const & new_rhs) { + if (!is_eqp(heq_lhs(e), new_lhs) || !is_eqp(heq_rhs(e), new_rhs)) + return mk_heq(new_lhs, new_rhs); + else + return e; +} // ======================================= diff --git a/src/kernel/expr_eq.h b/src/kernel/expr_eq.h index 729033c17..28fa25154 100644 --- a/src/kernel/expr_eq.h +++ b/src/kernel/expr_eq.h @@ -64,6 +64,7 @@ class expr_eq_fn { if (!apply(arg(a, i), arg(b, i))) return false; return true; + case expr_kind::HEq: return heq_lhs(a) == heq_lhs(b) && heq_rhs(a) == heq_rhs(b); case expr_kind::Pair: return pair_first(a) == pair_first(b) && pair_second(a) == pair_second(b) && pair_type(a) == pair_type(b); case expr_kind::Proj: return proj_first(a) == proj_first(b) && proj_arg(a) == proj_arg(b); case expr_kind::Sigma: diff --git a/src/kernel/for_each_fn.h b/src/kernel/for_each_fn.h index 6efdc4785..7cc51ad2b 100644 --- a/src/kernel/for_each_fn.h +++ b/src/kernel/for_each_fn.h @@ -85,6 +85,10 @@ class for_each_fn { } goto begin_loop; } + case expr_kind::HEq: + todo.emplace_back(heq_lhs(e), offset); + todo.emplace_back(heq_rhs(e), offset); + goto begin_loop; case expr_kind::Pair: todo.emplace_back(pair_first(e), offset); todo.emplace_back(pair_second(e), offset); diff --git a/src/kernel/free_vars.cpp b/src/kernel/free_vars.cpp index 54d2368e2..265376adf 100644 --- a/src/kernel/free_vars.cpp +++ b/src/kernel/free_vars.cpp @@ -42,7 +42,7 @@ protected: return var_idx(e) >= offset; case expr_kind::App: case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Let: case expr_kind::Sigma: - case expr_kind::Proj: case expr_kind::Pair: + case expr_kind::Proj: case expr_kind::Pair: case expr_kind::HEq: break; } @@ -86,6 +86,9 @@ protected: case expr_kind::Let: result = apply(let_type(e), offset) || apply(let_value(e), offset) || apply(let_body(e), offset + 1); break; + case expr_kind::HEq: + result = apply(heq_lhs(e), offset) || apply(heq_rhs(e), offset); + break; case expr_kind::Proj: result = apply(proj_arg(e), offset); break; @@ -178,6 +181,7 @@ class free_var_range_fn { case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Let: case expr_kind::Sigma: case expr_kind::Proj: case expr_kind::Pair: + case expr_kind::HEq: break; } @@ -215,6 +219,9 @@ class free_var_range_fn { case expr_kind::Let: result = std::max({apply(let_type(e)), apply(let_value(e)), dec(apply(let_body(e)))}); break; + case expr_kind::HEq: + result = std::max(apply(heq_lhs(e)), apply(heq_rhs(e))); + break; case expr_kind::Proj: result = apply(proj_arg(e)); break; @@ -301,6 +308,7 @@ protected: case expr_kind::App: case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Let: case expr_kind::Sigma: case expr_kind::Proj: case expr_kind::Pair: + case expr_kind::HEq: break; } @@ -345,6 +353,9 @@ protected: case expr_kind::Let: result = apply(let_type(e), offset) || apply(let_value(e), offset) || apply(let_body(e), offset + 1); break; + case expr_kind::HEq: + result = apply(heq_lhs(e), offset) || apply(heq_rhs(e), offset); + break; case expr_kind::Proj: result = apply(proj_arg(e), offset); break; diff --git a/src/kernel/kernel_decls.cpp b/src/kernel/kernel_decls.cpp index 6aeb89e18..41e17e089 100644 --- a/src/kernel/kernel_decls.cpp +++ b/src/kernel/kernel_decls.cpp @@ -8,7 +8,7 @@ Released under Apache 2.0 license as described in the file LICENSE. namespace lean { MK_CONSTANT(TypeU, name("TypeU")); MK_CONSTANT(Bool, name("Bool")); -MK_CONSTANT(heq_fn, name("heq")); +MK_CONSTANT(heq2_fn, name("heq2")); MK_CONSTANT(hrefl_fn, name("hrefl")); MK_CONSTANT(eq_fn, name("eq")); MK_CONSTANT(refl_fn, name("refl")); diff --git a/src/kernel/kernel_decls.h b/src/kernel/kernel_decls.h index 34768f45a..eddc52f67 100644 --- a/src/kernel/kernel_decls.h +++ b/src/kernel/kernel_decls.h @@ -9,10 +9,10 @@ expr mk_TypeU(); bool is_TypeU(expr const & e); expr mk_Bool(); bool is_Bool(expr const & e); -expr mk_heq_fn(); -bool is_heq_fn(expr const & e); -inline bool is_heq(expr const & e) { return is_app(e) && is_heq_fn(arg(e, 0)) && num_args(e) == 5; } -inline expr mk_heq(expr const & e1, expr const & e2, expr const & e3, expr const & e4) { return mk_app({mk_heq_fn(), e1, e2, e3, e4}); } +expr mk_heq2_fn(); +bool is_heq2_fn(expr const & e); +inline bool is_heq2(expr const & e) { return is_app(e) && is_heq2_fn(arg(e, 0)) && num_args(e) == 5; } +inline expr mk_heq2(expr const & e1, expr const & e2, expr const & e3, expr const & e4) { return mk_app({mk_heq2_fn(), e1, e2, e3, e4}); } expr mk_hrefl_fn(); bool is_hrefl_fn(expr const & e); inline expr mk_hrefl_th(expr const & e1, expr const & e2) { return mk_app({mk_hrefl_fn(), e1, e2}); } diff --git a/src/kernel/max_sharing.cpp b/src/kernel/max_sharing.cpp index b3e2a7dee..5d9fe1cc0 100644 --- a/src/kernel/max_sharing.cpp +++ b/src/kernel/max_sharing.cpp @@ -52,6 +52,9 @@ struct max_sharing_fn::imp { case expr_kind::Var: case expr_kind::Type: case expr_kind::Value: res = a; break; + case expr_kind::HEq: + res = update_heq(a, apply(heq_lhs(a)), apply(heq_rhs(a))); + break; case expr_kind::Pair: res = update_pair(a, [=](expr const & f, expr const & s, expr const & t) { return std::make_tuple(apply(f), apply(s), apply(t)); diff --git a/src/kernel/normalizer.cpp b/src/kernel/normalizer.cpp index c67882126..fea5e8310 100644 --- a/src/kernel/normalizer.cpp +++ b/src/kernel/normalizer.cpp @@ -250,6 +250,12 @@ class normalizer::imp { } break; } + case expr_kind::HEq: { + expr new_lhs = normalize(heq_lhs(a), s, k); + expr new_rhs = normalize(heq_rhs(a), s, k); + r = update_heq(a, new_lhs, new_rhs); + break; + } case expr_kind::Pair: { expr new_first = normalize(pair_first(a), s, k); expr new_second = normalize(pair_second(a), s, k); diff --git a/src/kernel/replace_fn.h b/src/kernel/replace_fn.h index 48202f342..616e341be 100644 --- a/src/kernel/replace_fn.h +++ b/src/kernel/replace_fn.h @@ -144,6 +144,14 @@ public: switch (e.kind()) { case expr_kind::Constant: case expr_kind::Type: case expr_kind::Value: case expr_kind::Var: case expr_kind::MetaVar: lean_unreachable(); // LCOV_EXCL_LINE + case expr_kind::HEq: + if (check_index(f, 0) && !visit(heq_lhs(e), offset)) + goto begin_loop; + if (check_index(f, 1) && !visit(heq_rhs(e), offset)) + goto begin_loop; + r = update_heq(e, rs(-2), rs(-1)); + pop_rs(2); + break; case expr_kind::Pair: if (check_index(f, 0) && !visit(pair_first(e), offset)) goto begin_loop; diff --git a/src/kernel/replace_visitor.cpp b/src/kernel/replace_visitor.cpp index 0f8d19ce0..b912d2bdb 100644 --- a/src/kernel/replace_visitor.cpp +++ b/src/kernel/replace_visitor.cpp @@ -31,6 +31,10 @@ expr replace_visitor::visit_app(expr const & e, context const & ctx) { lean_assert(is_app(e)); return update_app(e, [&](expr const & c) { return visit(c, ctx); }); } +expr replace_visitor::visit_heq(expr const & e, context const & ctx) { + lean_assert(is_heq(e)); + return update_heq(e, visit(heq_lhs(e), ctx), visit(heq_rhs(e), ctx)); +} expr replace_visitor::visit_abst(expr const & e, context const & ctx) { lean_assert(is_abstraction(e)); return update_abst(e, [&](expr const & t, expr const & b) { @@ -87,6 +91,7 @@ expr replace_visitor::visit(expr const & e, context const & ctx) { case expr_kind::Constant: return save_result(e, visit_constant(e, ctx), shared); case expr_kind::Var: return save_result(e, visit_var(e, ctx), shared); case expr_kind::MetaVar: return save_result(e, visit_metavar(e, ctx), shared); + case expr_kind::HEq: return save_result(e, visit_heq(e, ctx), shared); case expr_kind::Pair: return save_result(e, visit_pair(e, ctx), shared); case expr_kind::Proj: return save_result(e, visit_proj(e, ctx), shared); case expr_kind::App: return save_result(e, visit_app(e, ctx), shared); diff --git a/src/kernel/replace_visitor.h b/src/kernel/replace_visitor.h index d13b4ead2..735bf71ee 100644 --- a/src/kernel/replace_visitor.h +++ b/src/kernel/replace_visitor.h @@ -30,6 +30,7 @@ protected: virtual expr visit_constant(expr const &, context const &); virtual expr visit_var(expr const &, context const &); virtual expr visit_metavar(expr const &, context const &); + virtual expr visit_heq(expr const &, context const &); virtual expr visit_pair(expr const &, context const &); virtual expr visit_proj(expr const &, context const &); virtual expr visit_app(expr const &, context const &); diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 9a5b7746a..60fdffdbf 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -210,6 +210,10 @@ class type_checker::imp { } case expr_kind::Type: return mk_type(ty_level(e) + 1); + case expr_kind::HEq: + if (m_infer_only) + return Bool; + break; case expr_kind::App: case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Let: case expr_kind::Sigma: case expr_kind::Proj: @@ -225,6 +229,7 @@ class type_checker::imp { return it->second; } + expr r; switch (e.kind()) { case expr_kind::MetaVar: case expr_kind::Constant: case expr_kind::Type: case expr_kind::Value: lean_unreachable(); // LCOV_EXCL_LINE; @@ -238,13 +243,14 @@ class type_checker::imp { optional const & b = def.get_body(); lean_assert(b); expr t = infer_type_core(*b, def_ctx); - return save_result(e, lift_free_vars(t, var_idx(e) + 1), shared); + r = lift_free_vars(t, var_idx(e) + 1); + break; } case expr_kind::App: if (m_infer_only) { expr const & f = arg(e, 0); expr f_t = infer_type_core(f, ctx); - return save_result(e, get_range(f_t, e, ctx), shared); + r = get_range(f_t, e, ctx); } else { unsigned num = num_args(e); lean_assert(num >= 2); @@ -263,14 +269,23 @@ class type_checker::imp { throw app_type_mismatch_exception(env(), ctx, e, i, arg_types.size(), arg_types.data()); f_t = pi_body_at(f_t, c); i++; - if (i == num) - return save_result(e, f_t, shared); + if (i == num) { + r = f_t; + break; + } f_t = check_pi(f_t, e, ctx); } } + break; + case expr_kind::HEq: + lean_assert(!m_infer_only); + infer_type_core(heq_lhs(e), ctx); + infer_type_core(heq_rhs(e), ctx); + r = Bool; + break; case expr_kind::Pair: if (m_infer_only) { - return pair_type(e); + r = pair_type(e); } else { expr const & t = pair_type(e); expr sig = check_sigma(t, t, ctx); @@ -284,19 +299,21 @@ class type_checker::imp { if (!is_convertible(s_t, expected, ctx, mk_snd_justification)) { throw pair_type_mismatch_exception(env(), ctx, e, false, s_t, sig); } - return sig; + r = sig; } + break; case expr_kind::Proj: { expr t = check_sigma(infer_type_core(proj_arg(e), ctx), e, ctx); if (proj_first(e)) { - return abst_domain(t); + r = abst_domain(t); } else { expr const & b = abst_body(t); if (closed(b)) - return b; + r = b; else - return instantiate(b, mk_proj1(proj_arg(e))); + r = instantiate(b, mk_proj1(proj_arg(e))); } + break; } case expr_kind::Lambda: if (!m_infer_only) { @@ -305,10 +322,9 @@ class type_checker::imp { } { freset reset(m_cache); - return save_result(e, - mk_pi(abst_name(e), abst_domain(e), infer_type_core(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)))), - shared); + r = mk_pi(abst_name(e), abst_domain(e), infer_type_core(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)))); } + break; case expr_kind::Sigma: case expr_kind::Pi: { expr t1 = check_type(infer_type_core(abst_domain(e), ctx), abst_domain(e), ctx); if (is_bool(t1)) @@ -320,20 +336,22 @@ class type_checker::imp { t2 = check_type(infer_type_core(abst_body(e), new_ctx), abst_body(e), new_ctx); } if (is_bool(t2)) { - if (is_pi(e)) - return t2; - else + if (is_pi(e)) { + r = Bool; + break; + } else { t2 = Type(); + } } if (is_type(t1) && is_type(t2)) { - return save_result(e, mk_type(max(ty_level(t1), ty_level(t2))), shared); + r = mk_type(max(ty_level(t1), ty_level(t2))); } else { lean_assert(m_uc); justification jst = mk_max_type_justification(ctx, e); - expr r = m_menv->mk_metavar(ctx); + r = m_menv->mk_metavar(ctx); m_uc->push_back(mk_max_constraint(new_ctx, lift_free_vars(t1, 0, 1), t2, r, jst)); - return save_result(e, r, shared); } + break; } case expr_kind::Let: { optional lt; @@ -356,10 +374,11 @@ class type_checker::imp { { freset reset(m_cache); expr t = infer_type_core(let_body(e), extend(ctx, let_name(e), lt, let_value(e))); - return save_result(e, instantiate(t, let_value(e)), shared); + r = instantiate(t, let_value(e)); } + break; }} - lean_unreachable(); // LCOV_EXCL_LINE + return save_result(e, r, shared); } bool is_convertible_core(expr const & given, expr const & expected) { diff --git a/src/library/deep_copy.cpp b/src/library/deep_copy.cpp index 5e248feb8..87d32abaa 100644 --- a/src/library/deep_copy.cpp +++ b/src/library/deep_copy.cpp @@ -43,6 +43,7 @@ class deep_copy_fn { new_args.push_back(apply(old_arg)); return save_result(a, mk_app(new_args), sh); } + case expr_kind::HEq: return save_result(a, mk_heq(apply(heq_lhs(a)), apply(heq_rhs(a))), sh); case expr_kind::Pair: return save_result(a, mk_pair(apply(pair_first(a)), apply(pair_second(a)), apply(pair_type(a))), sh); case expr_kind::Proj: return save_result(a, mk_proj(proj_first(a), apply(proj_arg(a))), sh); case expr_kind::Lambda: return save_result(a, mk_lambda(abst_name(a), apply(abst_domain(a)), apply(abst_body(a))), sh); diff --git a/src/library/elaborator/elaborator.cpp b/src/library/elaborator/elaborator.cpp index 6e57a7278..c080eedb1 100644 --- a/src/library/elaborator/elaborator.cpp +++ b/src/library/elaborator/elaborator.cpp @@ -1675,6 +1675,12 @@ class elaborator::imp { m_conflict = mk_failure_justification(c); return false; } + case expr_kind::HEq: { + justification new_jst(new destruct_justification(c)); + push_new_eq_constraint(ctx, heq_lhs(a), heq_lhs(b), new_jst); + push_new_eq_constraint(ctx, heq_rhs(a), heq_rhs(b), new_jst); + return true; + } case expr_kind::Proj: if (proj_first(a) != proj_first(b)) { m_conflict = mk_failure_justification(c); diff --git a/src/library/expr_lt.cpp b/src/library/expr_lt.cpp index 130b76c03..b44da29d0 100644 --- a/src/library/expr_lt.cpp +++ b/src/library/expr_lt.cpp @@ -41,6 +41,11 @@ bool is_lt(expr const & a, expr const & b, bool use_hash) { return is_lt(arg(a, i), arg(b, i), use_hash); } lean_unreachable(); // LCOV_EXCL_LINE + case expr_kind::HEq: + if (heq_lhs(a) != heq_lhs(b)) + return is_lt(heq_lhs(a), heq_lhs(b), use_hash); + else + return is_lt(heq_rhs(a), heq_rhs(b), use_hash); case expr_kind::Pair: if (pair_first(a) != pair_first(b)) return is_lt(pair_first(a), pair_first(b), use_hash); diff --git a/src/library/fo_unify.cpp b/src/library/fo_unify.cpp index c65a3ab8e..83b2a67fe 100644 --- a/src/library/fo_unify.cpp +++ b/src/library/fo_unify.cpp @@ -62,6 +62,10 @@ optional fo_unify(expr e1, expr e2) { } } break; + case expr_kind::HEq: + todo.emplace_back(heq_lhs(e1), heq_lhs(e2)); + todo.emplace_back(heq_rhs(e1), heq_rhs(e2)); + break; case expr_kind::Proj: if (proj_first(e1) != proj_first(e2)) return optional(); diff --git a/src/library/hop_match.cpp b/src/library/hop_match.cpp index 369210893..ccde9e194 100644 --- a/src/library/hop_match.cpp +++ b/src/library/hop_match.cpp @@ -306,6 +306,10 @@ class hop_match_fn { } case expr_kind::Proj: return proj_first(p) == proj_first(t) && match(proj_arg(p), proj_arg(t), ctx, ctx_size); + case expr_kind::HEq: + return + match(heq_lhs(p), heq_lhs(t), ctx, ctx_size) && + match(heq_rhs(p), heq_rhs(t), ctx, ctx_size); case expr_kind::Pair: return match(pair_first(p), pair_first(t), ctx, ctx_size) && diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 8da1624cc..145f0eb96 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -495,6 +495,7 @@ static int expr_fields(lua_State * L) { case expr_kind::Type: return push_level(L, ty_level(e)); case expr_kind::Value: return to_value(e).push_lua(L); case expr_kind::App: lua_pushinteger(L, num_args(e)); expr_args(L); return 2; + case expr_kind::HEq: push_expr(L, heq_lhs(e)); push_expr(L, heq_rhs(e)); return 3; case expr_kind::Pair: push_expr(L, pair_first(e)); push_expr(L, pair_second(e)); push_expr(L, pair_type(e)); return 3; case expr_kind::Proj: lua_pushboolean(L, proj_first(e)); push_expr(L, proj_arg(e)); return 2; case expr_kind::Lambda: @@ -728,10 +729,14 @@ static void open_expr(lua_State * L) { SET_ENUM("Constant", expr_kind::Constant); SET_ENUM("Type", expr_kind::Type); SET_ENUM("Value", expr_kind::Value); + SET_ENUM("Pair", expr_kind::Pair); + SET_ENUM("Proj", expr_kind::Proj); SET_ENUM("App", expr_kind::App); + SET_ENUM("Sigma", expr_kind::Sigma); SET_ENUM("Lambda", expr_kind::Lambda); SET_ENUM("Pi", expr_kind::Pi); SET_ENUM("Let", expr_kind::Let); + SET_ENUM("HEq", expr_kind::HEq); SET_ENUM("MetaVar", expr_kind::MetaVar); lua_setglobal(L, "expr_kind"); } diff --git a/src/library/printer.cpp b/src/library/printer.cpp index 5d7167303..ab6cc2ac4 100644 --- a/src/library/printer.cpp +++ b/src/library/printer.cpp @@ -19,7 +19,7 @@ bool is_atomic(expr const & e) { return true; case expr_kind::App: case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Let: case expr_kind::Sigma: case expr_kind::Proj: - case expr_kind::Pair: + case expr_kind::Pair: case expr_kind::HEq: return false; } return false; @@ -118,6 +118,11 @@ struct print_expr_fn { case expr_kind::App: print_app(a, c); break; + case expr_kind::HEq: + print_child(heq_lhs(a), c); + out() << " == "; + print_child(heq_rhs(a), c); + break; case expr_kind::Pair: print_pair(a, c); break; diff --git a/src/library/rewriter/fo_match.cpp b/src/library/rewriter/fo_match.cpp index 424952da8..ca86a8d72 100644 --- a/src/library/rewriter/fo_match.cpp +++ b/src/library/rewriter/fo_match.cpp @@ -173,7 +173,7 @@ bool fo_match::match_main(expr const & p, expr const & t, unsigned o, subst_map return match_let(p, t, o, s); case expr_kind::MetaVar: return match_metavar(p, t, o, s); - case expr_kind::Proj: case expr_kind::Pair: case expr_kind::Sigma: + case expr_kind::Proj: case expr_kind::Pair: case expr_kind::Sigma: case expr_kind::HEq: // TODO(Leo): break; } diff --git a/src/library/rewriter/rewriter.h b/src/library/rewriter/rewriter.h index 9e980fb87..8ff027d6d 100644 --- a/src/library/rewriter/rewriter.h +++ b/src/library/rewriter/rewriter.h @@ -384,7 +384,7 @@ class apply_rewriter_fn { } } break; - case expr_kind::Proj: case expr_kind::Pair: case expr_kind::Sigma: + case expr_kind::HEq: case expr_kind::Proj: case expr_kind::Pair: case expr_kind::Sigma: // TODO(Leo): break; case expr_kind::Lambda: { diff --git a/src/library/simplifier/ceq.cpp b/src/library/simplifier/ceq.cpp index 69065a961..1ff223397 100644 --- a/src/library/simplifier/ceq.cpp +++ b/src/library/simplifier/ceq.cpp @@ -209,6 +209,10 @@ static bool is_permutation(expr const & lhs, expr const & rhs, unsigned offset, } else { return lhs == rhs; // free variable } + case expr_kind::HEq: + return + is_permutation(heq_lhs(lhs), heq_lhs(rhs), offset, p) && + is_permutation(heq_rhs(lhs), heq_rhs(rhs), offset, p); case expr_kind::Proj: return proj_first(lhs) == proj_first(rhs) && is_permutation(proj_arg(lhs), proj_arg(rhs), offset, p); case expr_kind::Pair: diff --git a/src/library/simplifier/simplifier.cpp b/src/library/simplifier/simplifier.cpp index b190aa9c4..9d8628cf0 100644 --- a/src/library/simplifier/simplifier.cpp +++ b/src/library/simplifier/simplifier.cpp @@ -193,6 +193,14 @@ class simplifier_cell::imp { } }; + static bool is_heq(expr const & e) { + return is_heq2(e); + } + + static expr mk_heq(expr const & A, expr const & B, expr const & a, expr const & b) { + return mk_heq2(A, B, a, b); + } + static expr mk_lambda(name const & n, expr const & d, expr const & b) { return ::lean::mk_lambda(n, d, b); } diff --git a/src/tests/kernel/expr.cpp b/src/tests/kernel/expr.cpp index f56bb8cdc..4ed7b20cb 100644 --- a/src/tests/kernel/expr.cpp +++ b/src/tests/kernel/expr.cpp @@ -84,6 +84,8 @@ static unsigned depth2(expr const & e) { return depth2(proj_arg(e)) + 1; case expr_kind::Pair: return std::max(depth2(pair_first(e)), depth2(pair_second(e))) + 1; + case expr_kind::HEq: + return std::max(depth2(heq_lhs(e)), depth2(heq_rhs(e))) + 1; } return 0; } @@ -145,6 +147,8 @@ static unsigned count_core(expr const & a, expr_set & s) { return count_core(proj_arg(a), s) + 1; case expr_kind::Pair: return count_core(pair_first(a), s) + count_core(pair_second(a), s) + count_core(pair_type(a), s) + 1; + case expr_kind::HEq: + return count_core(heq_lhs(a), s) + count_core(heq_rhs(a), s); } return 0; } diff --git a/src/tests/kernel/normalizer.cpp b/src/tests/kernel/normalizer.cpp index 13cfde2ac..6f2e357ab 100644 --- a/src/tests/kernel/normalizer.cpp +++ b/src/tests/kernel/normalizer.cpp @@ -85,6 +85,8 @@ unsigned count_core(expr const & a, expr_set & s) { return count_core(proj_arg(a), s) + 1; case expr_kind::Pair: return count_core(pair_first(a), s) + count_core(pair_second(a), s) + count_core(pair_type(a), s) + 1; + case expr_kind::HEq: + return count_core(heq_lhs(a), s) + count_core(heq_rhs(a), s); } return 0; }