diff --git a/src/frontends/lean/builtin_exprs.cpp b/src/frontends/lean/builtin_exprs.cpp index 2d70d74ef..d3b943f15 100644 --- a/src/frontends/lean/builtin_exprs.cpp +++ b/src/frontends/lean/builtin_exprs.cpp @@ -29,6 +29,7 @@ Author: Leonardo de Moura #include "frontends/lean/tokens.h" #include "frontends/lean/info_tactic.h" #include "frontends/lean/info_annotation.h" +#include "frontends/lean/structure_cmd.h" namespace lean { namespace notation { @@ -496,6 +497,7 @@ parse_table init_nud_table() { r = r.add({transition("proof", mk_ext_action(parse_proof_qed))}, x0); r = r.add({transition("sorry", mk_ext_action(parse_sorry))}, x0); r = r.add({transition("match", mk_ext_action(parse_match))}, x0); + init_structure_instance_parsing_rules(r); return r; } diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index 3ad60ca03..05d881774 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -40,6 +40,7 @@ Author: Leonardo de Moura #include "library/error_handling/error_handling.h" #include "library/definitional/equations.h" #include "frontends/lean/local_decls.h" +#include "frontends/lean/structure_cmd.h" #include "frontends/lean/class.h" #include "frontends/lean/tactic_hint.h" #include "frontends/lean/info_manager.h" @@ -1181,6 +1182,120 @@ expr elaborator::visit_decreasing(expr const & e, constraint_seq & cs) { return mk_decreasing(dec_app, dec_proof); } +bool elaborator::is_structure(expr const & S) { + expr const & I = get_app_fn(S); + return is_constant(I) && + inductive::is_inductive_decl(env(), const_name(I)) && + *inductive::get_num_intro_rules(env(), const_name(I)) == 1 && + *inductive::get_num_indices(env(), const_name(I)) == 0; +} + +expr elaborator::visit_structure_instance(expr const & e, constraint_seq & cs) { + expr S; + buffer field_names; + buffer field_values, using_exprs; + destruct_structure_instance(e, S, field_names, field_values, using_exprs); + lean_assert(field_names.size() == field_values.size()); + expr new_S = visit(S, cs); + if (!is_structure(new_S)) + throw_elaborator_exception("invalid structure instance, given type is not a structure", S); + buffer new_S_args; + expr I = get_app_args(new_S, new_S_args); + expr new_S_type = whnf(infer_type(new_S, cs), cs); + tag S_tag = S.get_tag(); + while (is_pi(new_S_type)) { + expr m = m_full_context.mk_meta(m_ngen, some_expr(binding_domain(new_S_type)), S_tag); + register_meta(m); + new_S_args.push_back(m); + new_S = mk_app(new_S, m, S_tag); + new_S_type = whnf(instantiate(binding_body(new_S_type), m), cs); + } + buffer field_used; + field_used.resize(field_names.size(), false); + buffer new_field_values; + for (expr const & v : field_values) + new_field_values.push_back(visit(v, cs)); + buffer using_exprs_used; + using_exprs_used.resize(using_exprs.size(), false); + buffer new_using_exprs; + buffer new_using_types; + for (expr const & u : using_exprs) { + expr new_u = visit(u, cs); + expr new_u_type = whnf(infer_type(new_u, cs), cs); + if (!is_structure(new_u_type)) + throw_elaborator_exception("invalid structure instance, type of 'using' argument is not a structure", u); + new_using_exprs.push_back(new_u); + new_using_types.push_back(new_u_type); + } + buffer intro_names; + get_intro_rule_names(env(), const_name(I), intro_names); + lean_assert(intro_names.size() == 1); + name const & S_mk_name = intro_names[0]; + tag result_tag = e.get_tag(); + expr S_mk = mk_constant(S_mk_name, const_levels(I), result_tag); + for (expr & arg : new_S_args) + S_mk = mk_app(S_mk, arg, result_tag); + expr S_mk_type = whnf(infer_type(S_mk, cs), cs); + while (is_pi(S_mk_type)) { + name n = binding_name(S_mk_type); + expr d_type = binding_domain(S_mk_type); + expr v; + unsigned i = 0; + for (; i < field_names.size(); i++) { + if (!field_used[i] && field_names[i] == n) { + field_used[i] = true; + v = new_field_values[i]; + break; + } + } + if (i == new_field_values.size()) { + // did not find explicit field + unsigned i = 0; + for (; i < new_using_exprs.size(); i++) { + // check if u_type structure has the given field. + expr const & u_type = new_using_types[i]; + buffer u_type_args; + expr const & J = get_app_args(u_type, u_type_args); + lean_assert(is_constant(J)); + name J_field_name = const_name(J) + n; + if (env().find(J_field_name)) { + tag u_tag = using_exprs[i].get_tag(); + v = mk_constant(J_field_name, const_levels(J), u_tag); + for (expr const & arg : u_type_args) + v = mk_app(v, arg, u_tag); + v = mk_app(v, new_using_exprs[i], u_tag); + using_exprs_used[i] = true; + break; + } + } + if (i == using_exprs.size()) { + // did not find field is using structure + v = m_full_context.mk_meta(m_ngen, some_expr(d_type), result_tag); + register_meta(v); + } + } + S_mk = mk_app(S_mk, v, result_tag); + expr v_type = infer_type(v, cs); + justification j = mk_app_justification(S_mk, v, d_type, v_type); + auto new_v_cs = ensure_has_type(v, v_type, d_type, j, m_relax_main_opaque); + expr new_v = new_v_cs.first; + cs += new_v_cs.second; + S_mk = update_app(S_mk, app_fn(S_mk), new_v); + S_mk_type = whnf(instantiate(binding_body(S_mk_type), new_v), cs); + } + for (unsigned i = 0; i < field_used.size(); i++) { + if (!field_used[i]) + throw_elaborator_exception(sstream() << "invalid structure instance, invalid field name '" + << field_names[i] << "'", field_values[i]); + } + for (unsigned i = 0; i < using_exprs_used.size(); i++) { + if (!using_exprs_used[i]) + throw_elaborator_exception(sstream() << "invalid structure instance, 'using' clause #" + << i + 1 << " is unnecessary", using_exprs[i]); + } + return S_mk; +} + expr elaborator::visit_core(expr const & e, constraint_seq & cs) { if (is_placeholder(e)) { return visit_placeholder(e, cs); @@ -1218,6 +1333,8 @@ expr elaborator::visit_core(expr const & e, constraint_seq & cs) { return visit_inaccessible(e, cs); } else if (is_decreasing(e)) { return visit_decreasing(e, cs); + } else if (is_structure_instance(e)) { + return visit_structure_instance(e, cs); } else { switch (e.kind()) { case expr_kind::Local: return e; diff --git a/src/frontends/lean/elaborator.h b/src/frontends/lean/elaborator.h index c53c2df6a..2541eca60 100644 --- a/src/frontends/lean/elaborator.h +++ b/src/frontends/lean/elaborator.h @@ -167,6 +167,9 @@ class elaborator : public coercion_info_manager { expr visit_decreasing(expr const & e, constraint_seq & cs); constraint mk_equations_cnstr(expr const & m, expr const & eqns); + bool is_structure(expr const & S); + expr visit_structure_instance(expr const & e, constraint_seq & cs); + public: elaborator(elaborator_context & ctx, name_generator const & ngen, bool nice_mvar_names = false); std::tuple operator()(list const & ctx, expr const & e, bool _ensure_type, diff --git a/src/frontends/lean/structure_cmd.cpp b/src/frontends/lean/structure_cmd.cpp index 443a17ee8..9ec24bf01 100644 --- a/src/frontends/lean/structure_cmd.cpp +++ b/src/frontends/lean/structure_cmd.cpp @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include #include #include +#include #include "util/sstream.h" #include "util/sexpr/option_declarations.h" #include "kernel/type_checker.h" @@ -28,6 +29,7 @@ Author: Leonardo de Moura #include "library/protected.h" #include "library/class.h" #include "library/util.h" +#include "library/kernel_serializer.h" #include "library/definitional/rec_on.h" #include "library/definitional/induction_on.h" #include "library/definitional/cases_on.h" @@ -57,22 +59,6 @@ static name * g_tmp_prefix = nullptr; static name * g_gen_eta = nullptr; static name * g_gen_proj_mk = nullptr; -void initialize_structure_cmd() { - g_tmp_prefix = new name(name::mk_internal_unique_name()); - g_gen_eta = new name{"structure", "eta_thm"}; - g_gen_proj_mk = new name{"structure", "proj_mk_thm"}; - register_bool_option(*g_gen_eta, LEAN_DEFAULT_STRUCTURE_ETA, - "(structure) automatically generate 'eta' theorem whenever declaring a new structure"); - register_bool_option(*g_gen_proj_mk, LEAN_DEFAULT_STRUCTURE_PROJ_MK, - "(structure) automatically gneerate projection over introduction theorem when declaring a new structure, the theorem is never generated for proof irrelevant fields"); -} - -void finalize_structure_cmd() { - delete g_tmp_prefix; - delete g_gen_eta; - delete g_gen_proj_mk; -} - bool get_structure_eta_thm(options const & o) { return o.get_bool(*g_gen_eta, LEAN_DEFAULT_STRUCTURE_ETA); } bool get_structure_proj_mk_thm(options const & o) { return o.get_bool(*g_gen_proj_mk, LEAN_DEFAULT_STRUCTURE_ETA); } @@ -162,7 +148,8 @@ struct structure_cmd_fn { /** \brief Parse structure parameters */ void parse_params() { - if (!m_p.curr_is_token(get_extends_tk()) && !m_p.curr_is_token(get_assign_tk()) && !m_p.curr_is_token(get_colon_tk())) { + if (!m_p.curr_is_token(get_extends_tk()) && !m_p.curr_is_token(get_assign_tk()) && + !m_p.curr_is_token(get_colon_tk())) { unsigned rbp = 0; m_p.parse_binders(m_params, rbp); } @@ -883,7 +870,8 @@ struct structure_cmd_fn { if (m_p.curr_is_token(get_assign_tk())) { m_p.check_token_next(get_assign_tk(), "invalid 'structure', ':=' expected"); m_mk_pos = m_p.pos(); - if (m_p.curr_is_token(get_lparen_tk()) || m_p.curr_is_token(get_lcurly_tk()) || m_p.curr_is_token(get_lbracket_tk())) { + if (m_p.curr_is_token(get_lparen_tk()) || m_p.curr_is_token(get_lcurly_tk()) || + m_p.curr_is_token(get_lbracket_tk())) { m_mk_short = LEAN_DEFAULT_STRUCTURE_INTRO; m_mk_infer = implicit_infer_kind::Implicit; } else { @@ -935,7 +923,130 @@ void get_structure_fields(environment const & env, name const & S, buffer } } +static name * g_structure_instance_name = nullptr; +static std::string * g_structure_instance_opcode = nullptr; + void register_structure_cmd(cmd_table & r) { add_cmd(r, cmd_info("structure", "declare a new structure/record type", structure_cmd)); } + +[[ noreturn ]] static void throw_se_ex() { throw exception("unexpected occurrence of 'structure instance' expression"); } + +// We encode a 'structure instance' expression using a macro. +// This is a trick to avoid creating a new kind of expression. +// 'Structure instance' expressions are temporary objects used by the elaborator. +// Example: Given +// structure point (A B : Type) := (x : A) (y : B) +// the structure instance +// {| point, x := 10, y := 20 |} +// is compiled into +// point.mk 10 20 +class structure_instance_macro_cell : public macro_definition_cell { + list m_fields; +public: + structure_instance_macro_cell(list const & fs):m_fields(fs) {} + virtual name get_name() const { return *g_structure_instance_name; } + virtual pair get_type(expr const &, extension_context &) const { throw_se_ex(); } + virtual optional expand(expr const &, extension_context &) const { throw_se_ex(); } + virtual void write(serializer & s) const { + s << *g_structure_instance_opcode; + write_list(s, m_fields); + } + list const & get_field_names() const { return m_fields; } +}; + +static expr mk_structure_instance(list const & fs, unsigned num, expr const * args) { + lean_assert(num >= length(fs) + 1); + macro_definition def(new structure_instance_macro_cell(fs)); + return mk_macro(def, num, args); +} + +bool is_structure_instance(expr const & e) { + return is_macro(e) && macro_def(e).get_name() == *g_structure_instance_name; +} + +void destruct_structure_instance(expr const & e, expr & t, buffer & field_names, + buffer & field_values, buffer & using_exprs) { + lean_assert(is_structure_instance(e)); + t = macro_arg(e, 0); + list const & fns = static_cast(macro_def(e).raw())->get_field_names(); + unsigned num_fileds = length(fns); + to_buffer(fns, field_names); + for (unsigned i = 1; i < num_fileds+1; i++) + field_values.push_back(macro_arg(e, i)); + for (unsigned i = num_fileds+1; i < macro_num_args(e); i++) + using_exprs.push_back(macro_arg(e, i)); +} + +static expr parse_struct_expr_core(parser & p, pos_info const & pos, bool curly_bar) { + expr t = p.parse_expr(); + buffer field_names; + buffer field_values; + buffer using_exprs; + while (p.curr_is_token(get_comma_tk())) { + p.next(); + if (p.curr_is_token(get_using_tk())) { + p.next(); + using_exprs.push_back(p.parse_expr()); + } else { + field_names.push_back(p.check_atomic_id_next("invalid structure instance, identifier expected")); + p.check_token_next(get_assign_tk(), "invalid structure instance, ':=' expected"); + field_values.push_back(p.parse_expr()); + } + } + if (curly_bar) + p.check_token_next(get_rcurlybar_tk(), "invalid structure expression, '|}' expected"); + else + p.check_token_next(get_rdcurly_tk(), "invalid structure expression, '⦄' expected"); + buffer args; + args.push_back(t); + args.append(field_values); + args.append(using_exprs); + return p.save_pos(mk_structure_instance(to_list(field_names), args.size(), args.data()), pos); +} + +static expr parse_struct_curly_bar(parser & p, unsigned, expr const *, pos_info const & pos) { + bool curly_bar = true; + return parse_struct_expr_core(p, pos, curly_bar); +} + +static expr parse_struct_dcurly(parser & p, unsigned, expr const *, pos_info const & pos) { + bool curly_bar = false; + return parse_struct_expr_core(p, pos, curly_bar); +} + +void init_structure_instance_parsing_rules(parse_table & r) { + expr x0 = mk_var(0); + r = r.add({notation::transition("{|", notation::mk_ext_action(parse_struct_curly_bar))}, x0); + r = r.add({notation::transition("⦃", notation::mk_ext_action(parse_struct_dcurly))}, x0); +} + +void initialize_structure_cmd() { + g_tmp_prefix = new name(name::mk_internal_unique_name()); + g_gen_eta = new name{"structure", "eta_thm"}; + g_gen_proj_mk = new name{"structure", "proj_mk_thm"}; + register_bool_option(*g_gen_eta, LEAN_DEFAULT_STRUCTURE_ETA, + "(structure) automatically generate 'eta' theorem whenever declaring a new structure"); + register_bool_option(*g_gen_proj_mk, LEAN_DEFAULT_STRUCTURE_PROJ_MK, + "(structure) automatically gneerate projection over introduction theorem when " + "declaring a new structure, the theorem is never generated for proof irrelevant fields"); + g_structure_instance_name = new name("structure instance"); + g_structure_instance_opcode = new std::string("STI"); + register_macro_deserializer(*g_structure_instance_opcode, + [](deserializer & d, unsigned num, expr const * args) { + list fs; + fs = read_list(d); + if (num < length(fs) + 1) + throw corrupted_stream_exception(); + return mk_structure_instance(fs, num, args); + }); +} + +void finalize_structure_cmd() { + delete g_tmp_prefix; + delete g_gen_eta; + delete g_gen_proj_mk; + delete g_structure_instance_opcode; + delete g_structure_instance_name; +} } diff --git a/src/frontends/lean/structure_cmd.h b/src/frontends/lean/structure_cmd.h index 7a88b90e5..90e14af78 100644 --- a/src/frontends/lean/structure_cmd.h +++ b/src/frontends/lean/structure_cmd.h @@ -6,7 +6,12 @@ Author: Leonardo de Moura */ #pragma once #include "frontends/lean/cmd_table.h" +#include "frontends/lean/parse_table.h" namespace lean { +void init_structure_instance_parsing_rules(parse_table & r); +bool is_structure_instance(expr const & e); +void destruct_structure_instance(expr const & e, expr & t, buffer & field_names, + buffer & field_values, buffer & using_exprs); bool is_structure(environment const & env, name const & S); void get_structure_fields(environment const & env, name const & S, buffer & fields); void register_structure_cmd(cmd_table & r); diff --git a/src/frontends/lean/token_table.cpp b/src/frontends/lean/token_table.cpp index 8716735ba..58d9784e6 100644 --- a/src/frontends/lean/token_table.cpp +++ b/src/frontends/lean/token_table.cpp @@ -76,6 +76,7 @@ void init_token_table(token_table & t) { {"if", 0}, {"then", 0}, {"else", 0}, {"by", 0}, {"from", 0}, {"(", g_max_prec}, {")", 0}, {"{", g_max_prec}, {"}", 0}, {"_", g_max_prec}, {"[", g_max_prec}, {"]", 0}, {"⦃", g_max_prec}, {"⦄", 0}, {".{", 0}, {"Type", g_max_prec}, + {"{|", g_max_prec}, {"|}", 0}, {"using", 0}, {"|", 0}, {"!", g_max_prec}, {"with", 0}, {"...", 0}, {",", 0}, {".", 0}, {":", 0}, {"::", 0}, {"calc", 0}, {":=", 0}, {"--", 0}, {"#", 0}, {"(*", 0}, {"/-", 0}, {"begin", g_max_prec}, {"proof", g_max_prec}, {"qed", 0}, {"@", g_max_prec}, diff --git a/src/frontends/lean/tokens.cpp b/src/frontends/lean/tokens.cpp index f602fd793..19c907e8c 100644 --- a/src/frontends/lean/tokens.cpp +++ b/src/frontends/lean/tokens.cpp @@ -18,6 +18,8 @@ static name * g_lcurly = nullptr; static name * g_rcurly = nullptr; static name * g_ldcurly = nullptr; static name * g_rdcurly = nullptr; +static name * g_lcurlybar = nullptr; +static name * g_rcurlybar = nullptr; static name * g_lbracket = nullptr; static name * g_rbracket = nullptr; static name * g_bar = nullptr; @@ -112,6 +114,8 @@ void initialize_tokens() { g_rcurly = new name("}"); g_ldcurly = new name("⦃"); g_rdcurly = new name("⦄"); + g_lcurlybar = new name("{|"); + g_rcurlybar = new name("|}"); g_lbracket = new name("["); g_rbracket = new name("]"); g_bar = new name("|"); @@ -279,6 +283,8 @@ void finalize_tokens() { delete g_lbracket; delete g_rdcurly; delete g_ldcurly; + delete g_rcurlybar; + delete g_lcurlybar; delete g_lcurly; delete g_rcurly; delete g_llevel_curly; @@ -301,6 +307,8 @@ name const & get_lcurly_tk() { return *g_lcurly; } name const & get_rcurly_tk() { return *g_rcurly; } name const & get_ldcurly_tk() { return *g_ldcurly; } name const & get_rdcurly_tk() { return *g_rdcurly; } +name const & get_lcurlybar_tk() { return *g_lcurlybar; } +name const & get_rcurlybar_tk() { return *g_rcurlybar; } name const & get_lbracket_tk() { return *g_lbracket; } name const & get_rbracket_tk() { return *g_rbracket; } name const & get_bar_tk() { return *g_bar; } diff --git a/src/frontends/lean/tokens.h b/src/frontends/lean/tokens.h index 9dbf0a78d..d2000e60a 100644 --- a/src/frontends/lean/tokens.h +++ b/src/frontends/lean/tokens.h @@ -20,6 +20,8 @@ name const & get_lcurly_tk(); name const & get_rcurly_tk(); name const & get_ldcurly_tk(); name const & get_rdcurly_tk(); +name const & get_lcurlybar_tk(); +name const & get_rcurlybar_tk(); name const & get_lbracket_tk(); name const & get_rbracket_tk(); name const & get_bar_tk(); diff --git a/tests/lean/run/struct_inst_exprs.lean b/tests/lean/run/struct_inst_exprs.lean new file mode 100644 index 000000000..d9ef90ea7 --- /dev/null +++ b/tests/lean/run/struct_inst_exprs.lean @@ -0,0 +1,25 @@ +open nat prod + +set_option pp.coercions true + +definition s : nat × nat := {| prod, pr1 := 10, pr2 := 20 |} + +structure test := +(A : Type) (a : A) (B : A → Type) (b : B a) + +definition s2 := ⦃ test, a := 3, b := 10 ⦄ + +eval s2 + +definition s3 := {| test, a := 20, using s2 |} + +eval s3 + +definition s4 := ⦃ test, A := nat, B := λ a, nat, a := 10, b := 10 ⦄ + +definition s5 : Σ a : nat, a > 0 := + ⦃ sigma, pr1 := 10, pr2 := of_is_true trivial ⦄ + +eval s5 + +check ⦃ unit ⦄