From 5e5ab1429d48b73d12edfb67f6f3e232b090ff1b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 3 Feb 2014 18:15:56 -0800 Subject: [PATCH] feat(frontends/lean): parse and pretty print sigma types This commit also fixes some bugs in the implementation of Sigma types. Signed-off-by: Leonardo de Moura --- src/emacs/lean-input.el | 2 + src/frontends/lean/notation.h | 3 +- src/frontends/lean/parser_cmds.cpp | 10 ++--- src/frontends/lean/parser_expr.cpp | 61 ++++++++++++++++++++---------- src/frontends/lean/parser_imp.h | 6 ++- src/frontends/lean/pp.cpp | 31 ++++++++++++--- src/frontends/lean/scanner.cpp | 20 +++++++++- src/frontends/lean/scanner.h | 4 +- src/kernel/expr.cpp | 4 ++ src/kernel/expr.h | 9 ++--- src/kernel/normalizer.cpp | 3 +- src/kernel/replace_visitor.cpp | 2 +- src/kernel/type_checker.cpp | 8 +++- src/kernel/update_expr.cpp | 17 +++++++-- src/kernel/update_expr.h | 8 +++- 15 files changed, 139 insertions(+), 49 deletions(-) diff --git a/src/emacs/lean-input.el b/src/emacs/lean-input.el index 2a3b92bbe..781b04277 100644 --- a/src/emacs/lean-input.el +++ b/src/emacs/lean-input.el @@ -519,6 +519,8 @@ order for the change to take effect." ("fun" . ("λ")) + ("X" . ("⨯")) + ;; Primes. ("'" . ,(lean-input-to-string-list "′″‴⁗")) diff --git a/src/frontends/lean/notation.h b/src/frontends/lean/notation.h index d057ff76b..9aa678028 100644 --- a/src/frontends/lean/notation.h +++ b/src/frontends/lean/notation.h @@ -7,6 +7,7 @@ Author: Leonardo de Moura #pragma once #include namespace lean { -constexpr unsigned g_arrow_precedence = 25; +constexpr unsigned g_arrow_precedence = 25; +constexpr unsigned g_cartesian_product_precedence = 25; constexpr unsigned g_app_precedence = std::numeric_limits::max(); } diff --git a/src/frontends/lean/parser_cmds.cpp b/src/frontends/lean/parser_cmds.cpp index 8224723c6..81e392eac 100644 --- a/src/frontends/lean/parser_cmds.cpp +++ b/src/frontends/lean/parser_cmds.cpp @@ -139,13 +139,13 @@ void parser_imp::parse_def_core(bool is_definition) { auto p = pos(); type_body = save(mk_placeholder(), p); } - pre_type = mk_abstraction(false, parameters, type_body); + pre_type = mk_abstraction(expr_kind::Pi, parameters, type_body); if (!is_definition && curr_is_period()) { - pre_val = mk_abstraction(true, parameters, mk_placeholder()); + pre_val = mk_abstraction(expr_kind::Lambda, parameters, mk_placeholder()); } else { check_assign_next("invalid definition, ':=' expected"); expr val_body = parse_expr(); - pre_val = mk_abstraction(true, parameters, val_body); + pre_val = mk_abstraction(expr_kind::Lambda, parameters, val_body); } } auto r = m_elaborator(id, pre_type, pre_val); @@ -210,7 +210,7 @@ void parser_imp::parse_variable_core(bool is_var) { parse_var_decl_parameters(parameters); check_colon_next("invalid variable/axiom declaration, ':' expected"); expr type_body = parse_expr(); - pre_type = mk_abstraction(false, parameters, type_body); + pre_type = mk_abstraction(expr_kind::Pi, parameters, type_body); } auto p = m_elaborator(pre_type); expr type = p.first; @@ -717,7 +717,7 @@ void parser_imp::parse_builtin() { parse_var_decl_parameters(parameters); check_colon_next("invalid builtin declaration, ':' expected"); expr type_body = parse_expr(); - auto p = m_elaborator(mk_abstraction(false, parameters, type_body)); + auto p = m_elaborator(mk_abstraction(expr_kind::Pi, parameters, type_body)); check_no_metavar(p, "invalid declaration, type still contains metavariables after elaboration"); type = p.first; } diff --git a/src/frontends/lean/parser_expr.cpp b/src/frontends/lean/parser_expr.cpp index 588897f21..b4179d9b6 100644 --- a/src/frontends/lean/parser_expr.cpp +++ b/src/frontends/lean/parser_expr.cpp @@ -594,6 +594,17 @@ expr parser_imp::parse_arrow(expr const & left) { return save(mk_arrow(left, right), p); } +/** \brief Parse expr '#' expr. */ +expr parser_imp::parse_cartesian_product(expr const & left) { + auto p = pos(); + next(); + mk_scope scope(*this); + register_binding(g_unused); + // The -1 is a trick to get right associativity in Pratt's parsers + expr right = parse_expr(g_cartesian_product_precedence-1); + return save(mk_cartesian_product(left, right), p); +} + /** \brief Parse '(' expr ')'. */ expr parser_imp::parse_lparen() { auto p = pos(); @@ -708,39 +719,46 @@ void parser_imp::parse_definition_parameters(parameter_buffer & result) { \brief Create a lambda/Pi abstraction, using the giving binders and body. */ -expr parser_imp::mk_abstraction(bool is_lambda, parameter_buffer const & parameters, expr const & body) { +expr parser_imp::mk_abstraction(expr_kind k, parameter_buffer const & parameters, expr const & body) { expr result = body; unsigned i = parameters.size(); while (i > 0) { --i; pos_info p = parameters[i].m_pos; - if (is_lambda) - result = save(mk_lambda(parameters[i].m_name, parameters[i].m_type, result), p); - else - result = save(mk_pi(parameters[i].m_name, parameters[i].m_type, result), p); + switch (k) { + case expr_kind::Lambda: result = save(mk_lambda(parameters[i].m_name, parameters[i].m_type, result), p); break; + case expr_kind::Pi: result = save(mk_pi(parameters[i].m_name, parameters[i].m_type, result), p); break; + case expr_kind::Sigma: result = save(mk_sigma(parameters[i].m_name, parameters[i].m_type, result), p); break; + default: lean_unreachable(); break; + } } return result; } /** \brief Parse lambda/Pi abstraction. */ -expr parser_imp::parse_abstraction(bool is_lambda) { +expr parser_imp::parse_abstraction(expr_kind k) { next(); mk_scope scope(*this); parameter_buffer parameters; parse_expr_parameters(parameters); check_comma_next("invalid abstraction, ',' expected"); expr result = parse_expr(); - return mk_abstraction(is_lambda, parameters, result); + return mk_abstraction(k, parameters, result); } /** \brief Parse lambda abstraction. */ expr parser_imp::parse_lambda() { - return parse_abstraction(true); + return parse_abstraction(expr_kind::Lambda); } /** \brief Parse Pi abstraction. */ expr parser_imp::parse_pi() { - return parse_abstraction(false); + return parse_abstraction(expr_kind::Pi); +} + +/** \brief Parse Sigma type */ +expr parser_imp::parse_sig() { + return parse_abstraction(expr_kind::Sigma); } /** \brief Parse exists */ @@ -825,6 +843,8 @@ expr parser_imp::parse_type(bool level_expected) { return save(Type(), p); } else if (curr() == scanner::token::Arrow) { return parse_arrow(save(Type(), p)); + } else if (curr() == scanner::token::CartesianProduct) { + return parse_cartesian_product(save(Type(), p)); } else { return save(mk_type(parse_level()), p); } @@ -936,6 +956,7 @@ expr parser_imp::parse_nud() { case scanner::token::LeftParen: return parse_lparen(); case scanner::token::Lambda: return parse_lambda(); case scanner::token::Pi: return parse_pi(); + case scanner::token::Sig: return parse_sig(); case scanner::token::Exists: return parse_exists(); case scanner::token::Let: return parse_let(); case scanner::token::IntVal: return parse_nat_int(); @@ -966,15 +987,16 @@ expr parser_imp::mk_app_left(expr const & left, expr const & arg) { */ expr parser_imp::parse_led(expr const & left) { switch (curr()) { - case scanner::token::Id: return parse_led_id(left); - case scanner::token::Arrow: return parse_arrow(left); - case scanner::token::LeftParen: return mk_app_left(left, parse_lparen()); - case scanner::token::IntVal: return mk_app_left(left, parse_nat_int()); - case scanner::token::DecimalVal: return mk_app_left(left, parse_decimal()); - case scanner::token::StringVal: return mk_app_left(left, parse_string()); - case scanner::token::Placeholder: return mk_app_left(left, parse_placeholder()); - case scanner::token::Type: return mk_app_left(left, parse_type(false)); - default: return left; + case scanner::token::Id: return parse_led_id(left); + case scanner::token::Arrow: return parse_arrow(left); + case scanner::token::CartesianProduct: return parse_cartesian_product(left); + case scanner::token::LeftParen: return mk_app_left(left, parse_lparen()); + case scanner::token::IntVal: return mk_app_left(left, parse_nat_int()); + case scanner::token::DecimalVal: return mk_app_left(left, parse_decimal()); + case scanner::token::StringVal: return mk_app_left(left, parse_string()); + case scanner::token::Placeholder: return mk_app_left(left, parse_placeholder()); + case scanner::token::Type: return mk_app_left(left, parse_type(false)); + default: return left; } } @@ -996,7 +1018,8 @@ unsigned parser_imp::curr_lbp() { return g_app_precedence; } } - case scanner::token::Arrow : return g_arrow_precedence; + case scanner::token::Arrow : return g_arrow_precedence; + case scanner::token::CartesianProduct: return g_cartesian_product_precedence; case scanner::token::LeftParen: case scanner::token::IntVal: case scanner::token::DecimalVal: case scanner::token::StringVal: case scanner::token::Type: case scanner::token::Placeholder: return g_app_precedence; diff --git a/src/frontends/lean/parser_imp.h b/src/frontends/lean/parser_imp.h index 941b174b9..1f269334c 100644 --- a/src/frontends/lean/parser_imp.h +++ b/src/frontends/lean/parser_imp.h @@ -292,14 +292,16 @@ private: expr parse_expr_macro(name const & id, pos_info const & p); expr parse_led_id(expr const & left); expr parse_arrow(expr const & left); + expr parse_cartesian_product(expr const & left); expr parse_lparen(); void parse_names(buffer> & result); void register_binding(name const & n); void parse_simple_parameters(parameter_buffer & result, bool implicit_decl, bool suppress_type); - expr mk_abstraction(bool is_lambda, parameter_buffer const & parameters, expr const & body); - expr parse_abstraction(bool is_lambda); + expr mk_abstraction(expr_kind k, parameter_buffer const & parameters, expr const & body); + expr parse_abstraction(expr_kind k); expr parse_lambda(); expr parse_pi(); + expr parse_sig(); expr parse_exists(); expr parse_let(); expr parse_type(bool level_expected); diff --git a/src/frontends/lean/pp.cpp b/src/frontends/lean/pp.cpp index b2692c635..c21c926d9 100644 --- a/src/frontends/lean/pp.cpp +++ b/src/frontends/lean/pp.cpp @@ -80,6 +80,9 @@ static format g_assign_fmt = highlight_keyword(format(":=")); static format g_geq_fmt = format("\u2265"); static format g_lift_fmt = highlight_keyword(format("lift")); static format g_inst_fmt = highlight_keyword(format("inst")); +static format g_sig_fmt = highlight_keyword(format("sig")); +static format g_cartesian_product_fmt = highlight_keyword(format("#")); +static format g_cartesian_product_n_fmt = highlight_keyword(format("\u2A2F")); static name g_pp_max_depth {"lean", "pp", "max_depth"}; static name g_pp_max_steps {"lean", "pp", "max_steps"}; @@ -471,7 +474,9 @@ class pp_fn { return op.get_precedence(); } else if (is_arrow(e)) { return g_arrow_precedence; - } else if (is_lambda(e) || is_pi(e) || is_let(e) || is_exists(e)) { + } else if (is_cartesian(e)) { + return g_cartesian_product_precedence; + } else if (is_lambda(e) || is_pi(e) || is_let(e) || is_exists(e) || is_sigma(e)) { return 0; } else { return g_app_precedence; @@ -821,6 +826,14 @@ class pp_fn { return pp_scoped_child(e, depth, g_arrow_precedence); } + result pp_cartesian_child(expr const & e, unsigned depth) { + return pp_scoped_child(e, depth, g_cartesian_product_precedence + 1); + } + + result pp_cartesian_body(expr const & e, unsigned depth) { + return pp_scoped_child(e, depth, g_cartesian_product_precedence); + } + template format pp_bnames(It const & begin, It const & end, bool use_line) { auto it = begin; @@ -899,6 +912,13 @@ class pp_fn { result p_rhs = pp_arrow_body(abst_body(e), depth); format r_format = group(format{p_lhs.first, space(), m_unicode ? g_arrow_n_fmt : g_arrow_fmt, line(), p_rhs.first}); return mk_result(r_format, p_lhs.second + p_rhs.second + 1); + } else if (is_cartesian(e) && !implicit_args) { + lean_assert(!T); + result p_lhs = pp_cartesian_child(abst_domain(e), depth); + result p_rhs = pp_cartesian_body(abst_body(e), depth); + format r_format = group(format{p_lhs.first, space(), m_unicode ? g_cartesian_product_n_fmt : g_cartesian_product_fmt, + line(), p_rhs.first}); + return mk_result(r_format, p_lhs.second + p_rhs.second + 1); } else { unsigned arrow_starting_at = get_arrow_starting_at(e); buffer> nested; @@ -909,11 +929,11 @@ class pp_fn { format head; if (!T && !implicit_args) { if (m_unicode) { - head = is_lambda(e) ? g_lambda_n_fmt : g_Pi_n_fmt; - head_indent = 2; + head = is_lambda(e) ? g_lambda_n_fmt : (is_pi(e) ? g_Pi_n_fmt : g_sig_fmt); + head_indent = is_sigma(e) ? 4 : 2; } else { - head = is_lambda(e) ? g_lambda_fmt : g_Pi_fmt; - head_indent = is_lambda(e) ? 4 : 3; + head = is_lambda(e) ? g_lambda_fmt : (is_pi(e) ? g_Pi_fmt : g_sig_fmt); + head_indent = is_pi(e) ? 3 : 4; } } format body_sep; @@ -1123,6 +1143,7 @@ class pp_fn { case expr_kind::Value: r = pp_value(e); break; case expr_kind::App: r = pp_app(e, depth); break; case expr_kind::Lambda: + case expr_kind::Sigma: case expr_kind::Pi: r = pp_abstraction(e, depth); break; case expr_kind::Type: r = pp_type(e); break; case expr_kind::Let: r = pp_let(e, depth); break; diff --git a/src/frontends/lean/scanner.cpp b/src/frontends/lean/scanner.cpp index e5f552051..e7921236a 100644 --- a/src/frontends/lean/scanner.cpp +++ b/src/frontends/lean/scanner.cpp @@ -29,6 +29,11 @@ static name g_placeholder_name("_"); static name g_have_name("have"); static name g_using_name("using"); static name g_by_name("by"); +static name g_sig_name("sig"); +static name g_tuple_name("tuple"); +static name g_proj_name("proj"); +static name g_cartesian_product_unicode("\u2A2F"); +static name g_cartesian_product("#"); static char g_normalized[256]; @@ -205,6 +210,12 @@ scanner::token scanner::read_a_symbol() { return token::Let; } else if (m_name_val == g_in_name) { return token::In; + } else if (m_name_val == g_sig_name) { + return token::Sig; + } else if (m_name_val == g_tuple_name) { + return token::Tuple; + } else if (m_name_val == g_proj_name) { + return token::Proj; } else if (m_name_val == g_placeholder_name) { return token::Placeholder; } else if (m_name_val == g_have_name) { @@ -236,6 +247,8 @@ scanner::token scanner::read_b_symbol(char prev) { m_name_val = name(m_buffer.c_str()); if (m_name_val == g_arrow_name) return token::Arrow; + else if (m_name_val == g_cartesian_product) + return token::CartesianProduct; else return token::Id; } @@ -255,6 +268,8 @@ scanner::token scanner::read_c_symbol() { m_name_val = name(m_buffer.c_str()); if (m_name_val == g_arrow_unicode) return token::Arrow; + if (m_name_val == g_cartesian_product_unicode) + return token::CartesianProduct; else if (m_name_val == g_lambda_unicode) return token::Lambda; else if (m_name_val == g_pi_unicode) @@ -442,12 +457,15 @@ std::ostream & operator<<(std::ostream & out, scanner::token const & t) { case scanner::token::IntVal: out << "Int"; break; case scanner::token::DecimalVal: out << "Dec"; break; case scanner::token::StringVal: out << "String"; break; - case scanner::token::Eq: out << "=="; break; case scanner::token::Assign: out << ":="; break; case scanner::token::Type: out << "Type"; break; + case scanner::token::Sig: out << "sig"; break; + case scanner::token::Proj: out << "proj"; break; + case scanner::token::Tuple: out << "tuple"; break; case scanner::token::Placeholder: out << "_"; break; case scanner::token::ScriptBlock: out << "Script"; break; case scanner::token::Have: out << "have"; break; + case scanner::token::CartesianProduct: out << "#"; break; case scanner::token::By: out << "by"; break; case scanner::token::Ellipsis: out << "..."; break; case scanner::token::Eof: out << "EOF"; break; diff --git a/src/frontends/lean/scanner.h b/src/frontends/lean/scanner.h index 404d7a6a1..d904d2e85 100644 --- a/src/frontends/lean/scanner.h +++ b/src/frontends/lean/scanner.h @@ -20,8 +20,8 @@ class scanner { public: enum class token { LeftParen, RightParen, LeftCurlyBracket, RightCurlyBracket, Colon, Comma, Period, Lambda, Pi, Arrow, - Let, In, Exists, Id, CommandId, IntVal, DecimalVal, StringVal, Eq, Assign, Type, Placeholder, - Have, By, ScriptBlock, Ellipsis, Eof + Sig, Tuple, Proj, Let, In, Exists, Id, CommandId, IntVal, DecimalVal, StringVal, Assign, Type, Placeholder, + Have, By, ScriptBlock, Ellipsis, CartesianProduct, Eof }; protected: int m_spos; // position in the current line of the stream diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 560fa08a4..412c4c4f6 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -297,6 +297,10 @@ bool is_arrow(expr const & t) { } } +bool is_cartesian(expr const & t) { + return is_sigma(t) && !has_free_var(abst_body(t), 0); +} + unsigned get_depth(expr const & e) { switch (e.kind()) { case expr_kind::Var: case expr_kind::Constant: case expr_kind::Type: diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 05e08d925..e1bb2105f 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -433,7 +433,7 @@ inline bool is_sigma(expr_cell * e) { return e->kind() == expr_kind::Sigma inline bool is_type(expr_cell * e) { return e->kind() == expr_kind::Type; } inline bool is_let(expr_cell * e) { return e->kind() == expr_kind::Let; } inline bool is_metavar(expr_cell * e) { return e->kind() == expr_kind::MetaVar; } -inline bool is_abstraction(expr_cell * e) { return is_lambda(e) || is_pi(e); } +inline bool is_abstraction(expr_cell * e) { return is_lambda(e) || is_pi(e) || is_sigma(e); } inline bool is_var(expr const & e) { return e.kind() == expr_kind::Var; } inline bool is_constant(expr const & e) { return e.kind() == expr_kind::Constant; } @@ -444,11 +444,12 @@ inline bool is_app(expr const & e) { return e.kind() == expr_kind::App; inline bool is_lambda(expr const & e) { return e.kind() == expr_kind::Lambda; } inline bool is_pi(expr const & e) { return e.kind() == expr_kind::Pi; } bool is_arrow(expr const & e); + bool is_cartesian(expr const & e); inline bool is_sigma(expr const & e) { return e.kind() == expr_kind::Sigma; } inline bool is_type(expr const & e) { return e.kind() == expr_kind::Type; } inline bool is_let(expr const & e) { return e.kind() == expr_kind::Let; } inline bool is_metavar(expr const & e) { return e.kind() == expr_kind::MetaVar; } -inline bool is_abstraction(expr const & e) { return is_lambda(e) || is_pi(e); } +inline bool is_abstraction(expr const & e) { return is_lambda(e) || is_pi(e) || is_sigma(e); } // ======================================= // ======================================= @@ -477,6 +478,7 @@ inline expr mk_pi(name const & n, expr const & t, expr const & e) { return expr( inline expr mk_sigma(name const & n, expr const & t, expr const & e) { return expr(new expr_sigma(n, t, e)); } inline bool is_default_arrow_var_name(name const & n) { return n == "a"; } inline expr mk_arrow(expr const & t, expr const & e) { return mk_pi(name("a"), t, e); } +inline expr mk_cartesian_product(expr const & t, expr const & e) { return mk_sigma(name("a"), t, e); } inline expr operator>>(expr const & t, expr const & e) { return mk_arrow(t, e); } inline expr mk_let(name const & n, optional const & t, expr const & v, expr const & e) { return expr(new expr_let(n, t, v, e)); } inline expr mk_let(name const & n, expr const & t, expr const & v, expr const & e) { return mk_let(n, some_expr(t), v, e); } @@ -692,9 +694,6 @@ template expr update_abst(expr const & e, F f) { return e; } } -inline expr update_abst(expr const & e, expr const & new_d, expr const & new_b) { - return update_abst(e, [&](expr const &, expr const &) { return mk_pair(new_d, new_b); }); -} template expr update_let(expr const & e, F f) { static_assert(std::is_same const &, expr const &, expr const &)>::type, std::tuple, expr, expr>>::value, diff --git a/src/kernel/normalizer.cpp b/src/kernel/normalizer.cpp index 3bc5b2108..c67882126 100644 --- a/src/kernel/normalizer.cpp +++ b/src/kernel/normalizer.cpp @@ -12,6 +12,7 @@ Author: Leonardo de Moura #include "util/buffer.h" #include "util/interrupt.h" #include "util/sexpr/options.h" +#include "kernel/update_expr.h" #include "kernel/normalizer.h" #include "kernel/expr.h" #include "kernel/expr_maps.h" @@ -158,7 +159,7 @@ class normalizer::imp { expr new_d = reify(normalize(abst_domain(e), s, k), k); m_cache.clear(); // make sure we do not reuse cached values from the previous call expr new_b = reify(normalize(abst_body(e), extend(s, mk_var(k)), k+1), k+1); - return update_abst(e, new_d, new_b); + return update_abstraction(e, new_d, new_b); } else { lean_assert(is_metavar(e)); // We use the following trick to reify a metavariable in the context of the value_stack s, and context ctx. diff --git a/src/kernel/replace_visitor.cpp b/src/kernel/replace_visitor.cpp index 94f1845ac..0f8d19ce0 100644 --- a/src/kernel/replace_visitor.cpp +++ b/src/kernel/replace_visitor.cpp @@ -51,7 +51,7 @@ expr replace_visitor::visit_pi(expr const & e, context const & ctx) { return visit_abst(e, ctx); } expr replace_visitor::visit_sigma(expr const & e, context const & ctx) { - lean_assert(is_pi(e)); + lean_assert(is_sigma(e)); return visit_abst(e, ctx); } expr replace_visitor::visit_let(expr const & e, context const & ctx) { diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 9dea3f011..380a4fe2c 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -318,8 +318,12 @@ class type_checker::imp { freset reset(m_cache); t2 = check_type(infer_type_core(abst_body(e), new_ctx), abst_body(e), new_ctx); } - if (is_pi(e) && is_bool(t2)) - return t2; + if (is_bool(t2)) { + if (is_pi(e)) + return t2; + else + t2 = Type(); + } if (is_type(t1) && is_type(t2)) { return save_result(e, mk_type(max(ty_level(t1), ty_level(t2))), shared); } else { diff --git a/src/kernel/update_expr.cpp b/src/kernel/update_expr.cpp index 67a4b67bc..89982daa2 100644 --- a/src/kernel/update_expr.cpp +++ b/src/kernel/update_expr.cpp @@ -53,11 +53,20 @@ expr update_pi(expr const & pi, expr const & d, expr const & b) { return mk_pi(abst_name(pi), d, b); } -expr update_abstraction(expr const & abst, expr const & d, expr const & b) { - if (is_lambda(abst)) - return update_lambda(abst, d, b); +expr update_sigma(expr const & sig, expr const & d, expr const & b) { + if (is_eqp(abst_domain(sig), d) && is_eqp(abst_body(sig), b)) + return sig; else - return update_pi(abst, d, b); + return mk_sigma(abst_name(sig), d, b); +} + +expr update_abstraction(expr const & abst, expr const & d, expr const & b) { + switch (abst.kind()) { + case expr_kind::Lambda: return update_lambda(abst, d, b); + case expr_kind::Pi: return update_pi(abst, d, b); + case expr_kind::Sigma: return update_sigma(abst, d, b); + default: lean_unreachable(); + } } expr update_let(expr const & let, optional const & t, expr const & v, expr const & b) { diff --git a/src/kernel/update_expr.h b/src/kernel/update_expr.h index b693f45b7..2c82e037e 100644 --- a/src/kernel/update_expr.h +++ b/src/kernel/update_expr.h @@ -23,11 +23,17 @@ inline expr update_app(expr const & app, buffer const & new_args) { return */ expr update_lambda(expr const & lambda, expr const & d, expr const & b); /** - \brief Return a pi expression based on \c pi with domain \c d and \c body b. + \brief Return a Pi expression based on \c pi with domain \c d and \c body b. \remark Return \c pi if the given domain and body are (pointer) equal to the ones in \c pi. */ expr update_pi(expr const & pi, expr const & d, expr const & b); +/** + \brief Return a Sigma expression based on \c sig with domain \c d and \c body b. + + \remark Return \c sig if the given domain and body are (pointer) equal to the ones in \c sig. +*/ +expr update_sigma(expr const & sig, expr const & d, expr const & b); /** \brief Return a lambda/pi expression based on \c abst with domain \c d and \c body b. */