diff --git a/src/emacs/lean-syntax.el b/src/emacs/lean-syntax.el index 03b24f9fd..93c36b498 100644 --- a/src/emacs/lean-syntax.el +++ b/src/emacs/lean-syntax.el @@ -15,7 +15,7 @@ "options" "precedence" "postfix" "prefix" "calc_trans" "calc_subst" "calc_refl" "infix" "infixl" "infixr" "notation" "eval" "check" "exit" "coercion" "end" "using" "namespace" "including" "instance" "class" "section" - "set_option" "add_rewrite" "extends") + "set_option" "add_rewrite" "extends" "include" "omit") "lean keywords") (defconst lean-syntax-table diff --git a/src/frontends/lean/decl_cmds.cpp b/src/frontends/lean/decl_cmds.cpp index e03622b1c..a29dc39a9 100644 --- a/src/frontends/lean/decl_cmds.cpp +++ b/src/frontends/lean/decl_cmds.cpp @@ -334,9 +334,6 @@ environment definition_cmd_core(parser & p, bool is_theorem, bool is_opaque, boo erase_local_binder_info(ps); value = Fun(ps, value, p); } - - update_univ_parameters(ls_buffer, collect_univ_params(value, collect_univ_params(type)), p); - ls = to_list(ls_buffer.begin(), ls_buffer.end()); } unsigned end_line = p.pos().first; @@ -361,6 +358,8 @@ environment definition_cmd_core(parser & p, bool is_theorem, bool is_opaque, boo section_value_ps.append(section_ps); erase_local_binder_info(section_value_ps); value = Fun_as_is(section_value_ps, value, p); + update_univ_parameters(ls_buffer, collect_univ_params(value, collect_univ_params(type)), p); + ls = to_list(ls_buffer.begin(), ls_buffer.end()); levels section_ls = collect_section_levels(ls, p); while (!section_ps.empty() && p.is_section_variable(section_ps.back())) section_ps.pop_back(); // we do not fix section variables @@ -368,6 +367,9 @@ environment definition_cmd_core(parser & p, bool is_theorem, bool is_opaque, boo param = mk_explicit(param); expr ref = mk_implicit(mk_app(mk_explicit(mk_constant(real_n, section_ls)), section_ps)); p.add_local_expr(n, ref); + } else { + update_univ_parameters(ls_buffer, collect_univ_params(value, collect_univ_params(type)), p); + ls = to_list(ls_buffer.begin(), ls_buffer.end()); } expr pre_type = type; expr pre_value = value; @@ -488,6 +490,36 @@ environment protected_definition_cmd(parser & p) { return definition_cmd_core(p, is_theorem, is_opaque, false, true); } +environment include_cmd_core(parser & p, bool include) { + if (!p.curr_is_identifier()) + throw parser_error(sstream() << "invalid include/omit command, identifier expected", p.pos()); + while (p.curr_is_identifier()) { + auto pos = p.pos(); + name n = p.get_name_val(); + p.next(); + if (!p.get_local(n)) + throw parser_error(sstream() << "invalid include/omit command, '" << n << "' is not a section parameter/variable", pos); + if (include) { + if (p.is_include_variable(n)) + throw parser_error(sstream() << "invalid include command, '" << n << "' has already been included", pos); + p.include_variable(n); + } else { + if (!p.is_include_variable(n)) + throw parser_error(sstream() << "invalid omit command, '" << n << "' has not been included", pos); + p.omit_variable(n); + } + } + return p.env(); +} + +environment include_cmd(parser & p) { + return include_cmd_core(p, true); +} + +environment omit_cmd(parser & p) { + return include_cmd_core(p, false); +} + void register_decl_cmds(cmd_table & r) { add_cmd(r, cmd_info("universe", "declare a global universe level", universe_cmd)); add_cmd(r, cmd_info("variable", "declare a new variable", variable_cmd)); @@ -502,5 +534,7 @@ void register_decl_cmds(cmd_table & r) { add_cmd(r, cmd_info("private", "add new private definition/theorem", private_definition_cmd)); add_cmd(r, cmd_info("protected", "add new protected definition/theorem", protected_definition_cmd)); add_cmd(r, cmd_info("theorem", "add new theorem", theorem_cmd)); + add_cmd(r, cmd_info("include", "force section parameter/variable to be included", include_cmd)); + add_cmd(r, cmd_info("omit", "undo 'include' command", omit_cmd)); } } diff --git a/src/frontends/lean/inductive_cmd.cpp b/src/frontends/lean/inductive_cmd.cpp index 0bfe4dbd9..0ca96bea8 100644 --- a/src/frontends/lean/inductive_cmd.cpp +++ b/src/frontends/lean/inductive_cmd.cpp @@ -413,6 +413,12 @@ struct inductive_cmd_fn { /** \brief Collect section local parameters used in the inductive decls */ void collect_section_locals(buffer const & decls, expr_struct_set & ls) { + buffer include_vars; + m_p.get_include_variables(include_vars); + for (expr const & param : include_vars) { + collect_locals(mlocal_type(param), ls); + ls.insert(param); + } for (auto const & d : decls) { collect_locals(inductive_decl_type(d), ls); for (auto const & ir : inductive_decl_intros(d)) @@ -665,9 +671,9 @@ struct inductive_cmd_fn { parser::local_scope scope(m_p); parse_inductive_decls(decls); } - include_section_levels(decls); buffer section_params; abstract_section_locals(decls, section_params); + include_section_levels(decls); m_num_params += section_params.size(); declare_inductive_types(decls); unsigned num_univ_params = m_levels.size(); diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index 60dddd4fe..aa71c670f 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -96,6 +96,7 @@ parser::parser(environment const & env, io_state const & ios, m_local_level_decls = s->m_lds; m_local_decls = s->m_eds; m_variables = s->m_vars; + m_include_vars = s->m_include_vars; m_options_stack = s->m_options_stack; } m_num_threads = num_threads; @@ -445,6 +446,12 @@ unsigned parser::get_local_index(name const & n) const { return m_local_decls.find_idx(n); } +void parser::get_include_variables(buffer & vars) const { + m_include_vars.for_each([&](name const & n) { + vars.push_back(*get_local(n)); + }); +} + static unsigned g_level_add_prec = 10; static unsigned g_level_cup_prec = 5; @@ -1352,7 +1359,7 @@ void parser::save_snapshot() { if (!m_snapshot_vector) return; if (m_snapshot_vector->empty() || static_cast(m_snapshot_vector->back().m_line) != m_scanner.get_line()) - m_snapshot_vector->push_back(snapshot(m_env, m_local_level_decls, m_local_decls, m_variables, + m_snapshot_vector->push_back(snapshot(m_env, m_local_level_decls, m_local_decls, m_variables, m_include_vars, m_options_stack, m_ios.get_options(), m_scanner.get_line())); } diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index 90fc5b3bb..d2d1307b3 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -52,15 +52,17 @@ struct snapshot { local_level_decls m_lds; local_expr_decls m_eds; name_set m_vars; // subset of m_eds that is tagged as section variables + name_set m_include_vars; // subset of m_eds that must be includes options_stack m_options_stack; options m_options; unsigned m_line; snapshot():m_line(0) {} snapshot(environment const & env, options const & o):m_env(env), m_options(o), m_line(1) {} snapshot(environment const & env, local_level_decls const & lds, local_expr_decls const & eds, - name_set const & vars, options_stack const & os, options const & opts, + name_set const & vars, name_set const & includes, options_stack const & os, options const & opts, unsigned line): - m_env(env), m_lds(lds), m_eds(eds), m_vars(vars), m_options_stack(os), m_options(opts), m_line(line) {} + m_env(env), m_lds(lds), m_eds(eds), m_vars(vars), m_include_vars(includes), + m_options_stack(os), m_options(opts), m_line(line) {} }; typedef std::vector snapshot_vector; @@ -78,6 +80,7 @@ class parser { local_level_decls m_local_level_decls; local_expr_decls m_local_decls; name_set m_variables; // subset of m_local_decls that is marked as variables + name_set m_include_vars; // subset of m_local_decls that is marked as include options_stack m_options_stack; pos_info m_last_cmd_pos; pos_info m_last_script_pos; @@ -316,6 +319,10 @@ public: void add_local(expr const & p) { return add_local_expr(local_pp_name(p), p); } bool is_section_variable(name const & n) const { return m_variables.contains(n); } bool is_section_variable(expr const & e) const { return is_section_variable(local_pp_name(e)); } + void include_variable(name const & n) { m_include_vars.insert(n); } + void omit_variable(name const & n) { m_include_vars.erase(n); } + bool is_include_variable(name const & n) const { return m_include_vars.contains(n); } + void get_include_variables(buffer & vars) const; /** \brief Position of the local level declaration named \c n in the sequence of local level decls. */ unsigned get_local_level_index(name const & n) const; /** \brief Position of the local declaration named \c n in the sequence of local decls. */ diff --git a/src/frontends/lean/token_table.cpp b/src/frontends/lean/token_table.cpp index be0b98bb2..0f9b405e5 100644 --- a/src/frontends/lean/token_table.cpp +++ b/src/frontends/lean/token_table.cpp @@ -86,7 +86,8 @@ void init_token_table(token_table & t) { "inductive", "record", "renaming", "extends", "structure", "module", "universe", "precedence", "infixl", "infixr", "infix", "postfix", "prefix", "notation", "context", "exit", "set_option", "open", "export", "calc_subst", "calc_refl", "calc_trans", "tactic_hint", - "add_begin_end_tactic", "set_begin_end_tactic", "instance", "class", "#erase_cache", nullptr}; + "add_begin_end_tactic", "set_begin_end_tactic", "instance", "class", + "include", "omit", "#erase_cache", nullptr}; pair aliases[] = {{g_lambda_unicode, "fun"}, {"forall", "Pi"}, {g_forall_unicode, "Pi"}, {g_pi_unicode, "Pi"}, diff --git a/src/frontends/lean/util.cpp b/src/frontends/lean/util.cpp index a2d54c42c..bb925f012 100644 --- a/src/frontends/lean/util.cpp +++ b/src/frontends/lean/util.cpp @@ -75,6 +75,12 @@ levels collect_section_levels(level_param_names const & ls, parser & p) { // Collect local (section) constants occurring in type and value, sort them, and store in section_ps void collect_section_locals(expr const & type, expr const & value, parser const & p, buffer & section_ps) { expr_struct_set ls; + buffer include_vars; + p.get_include_variables(include_vars); + for (expr const & param : include_vars) { + collect_locals(mlocal_type(param), ls); + ls.insert(param); + } collect_locals(type, ls); collect_locals(value, ls); sort_section_params(ls, p, section_ps); diff --git a/tests/lean/omit.lean b/tests/lean/omit.lean new file mode 100644 index 000000000..01b9fc222 --- /dev/null +++ b/tests/lean/omit.lean @@ -0,0 +1,22 @@ +section + parameter A : Type + parameter a : A + parameter c : A + omit A + include A + include A + omit A + parameter B : Type + parameter b : B + parameter d : B + include A + include a + include c + definition foo := b + + inductive tst (C : Type) := + mk : tst C +end + +check foo +check tst diff --git a/tests/lean/omit.lean.expected.out b/tests/lean/omit.lean.expected.out new file mode 100644 index 000000000..c85494d57 --- /dev/null +++ b/tests/lean/omit.lean.expected.out @@ -0,0 +1,4 @@ +omit.lean:5:7: error: invalid omit command, 'A' has not been included +omit.lean:7:10: error: invalid include command, 'A' has already been included +foo : Π (A : Type), A → A → (Π (B : Type), B → B) +tst : Π (A : Type), A → A → Type → Type