diff --git a/src/frontends/lean/builtin_cmds.cpp b/src/frontends/lean/builtin_cmds.cpp index 49e35cfb6..1a027ee5e 100644 --- a/src/frontends/lean/builtin_cmds.cpp +++ b/src/frontends/lean/builtin_cmds.cpp @@ -116,14 +116,53 @@ environment namespace_cmd(parser & p) { return push_scope(p.env(), p.ios(), scope_kind::Namespace, n); } +static void redeclare_aliases(parser & p, + list> old_level_entries, + list> old_entries) { + environment const & env = p.env(); + if (!in_section_or_context(env)) + return; + list> new_entries = p.get_local_entries(); + buffer> to_redeclare; + name_set popped_locals; + while (!is_eqp(old_entries, new_entries)) { + pair entry = head(old_entries); + if (is_section_local_ref(entry.second)) + to_redeclare.push_back(entry); + else if (is_local(entry.second)) + popped_locals.insert(mlocal_name(entry.second)); + old_entries = tail(old_entries); + } + name_set popped_levels; + list> new_level_entries = p.get_local_level_entries(); + while (!is_eqp(old_level_entries, new_level_entries)) { + level const & l = head(old_level_entries).second; + if (is_param(l)) + popped_levels.insert(param_id(l)); + old_level_entries = tail(old_level_entries); + } + + for (auto const & entry : to_redeclare) { + expr new_ref = update_section_local_ref(entry.second, popped_levels, popped_locals); + if (!is_constant(new_ref)) + p.add_local_expr(entry.first, new_ref); + } +} + environment end_scoped_cmd(parser & p) { + list> level_entries = p.get_local_level_entries(); + list> entries = p.get_local_entries(); if (in_section_or_context(p.env())) p.pop_local_scope(); if (p.curr_is_identifier()) { name n = p.check_atomic_id_next("invalid end of scope, atomic identifier expected"); - return pop_scope(p.env(), n); + environment env = pop_scope(p.env(), n); + redeclare_aliases(p, level_entries, entries); + return env; } else { - return pop_scope(p.env()); + environment env = pop_scope(p.env()); + redeclare_aliases(p, level_entries, entries); + return env; } } diff --git a/src/frontends/lean/local_decls.h b/src/frontends/lean/local_decls.h index 56df2dea6..407370c2a 100644 --- a/src/frontends/lean/local_decls.h +++ b/src/frontends/lean/local_decls.h @@ -20,29 +20,30 @@ namespace lean { template class local_decls { typedef name_map> map; - typedef list>> scopes; - map m_map; - unsigned m_counter; - scopes m_scopes; - list m_values; + typedef list> entries; + typedef std::tuple scope; + typedef list scopes; + map m_map; + unsigned m_counter; + scopes m_scopes; + entries m_entries; public: local_decls():m_counter(1) {} - local_decls(local_decls const & d):m_map(d.m_map), m_counter(d.m_counter), m_scopes(d.m_scopes) {} + local_decls(local_decls const & d): + m_map(d.m_map), m_counter(d.m_counter), m_scopes(d.m_scopes), m_entries(d.m_entries) {} void insert(name const & k, V const & v) { m_map.insert(k, mk_pair(v, m_counter)); m_counter++; - m_values = cons(v, m_values); + m_entries = cons(mk_pair(k, v), m_entries); } V const * find(name const & k) const { auto it = m_map.find(k); return it ? &(it->first) : nullptr; } unsigned find_idx(name const & k) const { auto it = m_map.find(k); return it ? it->second : 0; } bool contains(name const & k) const { return m_map.contains(k); } - list const & get_values() const { return m_values; } - void push() { m_scopes = scopes(std::make_tuple(m_map, m_counter, m_values), m_scopes); } + entries const & get_entries() const { return m_entries; } + void push() { m_scopes = cons(scope(m_map, m_counter, m_entries), m_scopes); } void pop() { lean_assert(!is_nil(m_scopes)); - m_map = std::get<0>(head(m_scopes)); - m_counter = std::get<1>(head(m_scopes)); - m_values = std::get<2>(head(m_scopes)); + std::tie(m_map, m_counter, m_entries) = head(m_scopes); m_scopes = tail(m_scopes); } struct mk_scope { diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index c2b3275b5..6a2cff199 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -464,7 +464,11 @@ void parser::get_include_variables(buffer & vars) const { } list parser::locals_to_context() const { - return filter(m_local_decls.get_values(), [](expr const & e) { return is_local(e); }); + return map_filter(m_local_decls.get_entries(), + [](pair const & p, expr & out) { + out = p.second; + return is_local(p.second); + }); } static unsigned g_level_add_prec = 10; diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index fb9e595bd..7fbbd87e7 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -348,9 +348,11 @@ public: expr const * get_local(name const & n) const { return m_local_decls.find(n); } /** \brief Return local declarations as a list of local constants. */ list locals_to_context() const; - - /** - \brief By default, when the parser finds a unknown identifier, it signs an error. + /** \brief Return all local declarations and aliases */ + list> const & get_local_entries() const { return m_local_decls.get_entries(); } + /** \brief Return all local level declarations */ + list> const & get_local_level_entries() const { return m_local_level_decls.get_entries(); } + /** \brief By default, when the parser finds a unknown identifier, it signs an error. This scope object temporarily changes this behavior. In any scope where this object is declared, the parse creates a constant even when the identifier is unknown. This behavior is useful when we are trying to parse mutually recursive declarations. diff --git a/src/frontends/lean/util.cpp b/src/frontends/lean/util.cpp index 53a2ccbf5..a4a14045a 100644 --- a/src/frontends/lean/util.cpp +++ b/src/frontends/lean/util.cpp @@ -114,6 +114,61 @@ expr mk_section_local_ref(name const & n, levels const & sec_ls, unsigned num_se return mk_implicit(mk_app(mk_explicit(mk_constant(n, sec_ls)), params)); } +bool is_section_local_ref(expr const & e) { + if (!is_implicit(e)) + return false; + expr const & imp_arg = get_implicit_arg(e); + if (!is_app(imp_arg)) + return false; + buffer locals; + expr const & f = get_app_args(imp_arg, locals); + return + is_explicit(f) && + is_constant(get_explicit_arg(f)) && + std::all_of(locals.begin(), locals.end(), + [](expr const & l) { + return is_explicit(l) && is_local(get_explicit_arg(l)); + }); +} + +expr update_section_local_ref(expr const & e, name_set const & lvls_to_remove, name_set const & locals_to_remove) { + lean_assert(is_section_local_ref(e)); + if (locals_to_remove.empty() && lvls_to_remove.empty()) + return e; + buffer locals; + expr const & f = get_app_args(get_implicit_arg(e), locals); + lean_assert(is_explicit(f)); + + expr new_f; + if (!lvls_to_remove.empty()) { + expr const & c = get_explicit_arg(f); + lean_assert(is_constant(c)); + new_f = mk_explicit(update_constant(c, filter(const_levels(c), [&](level const & l) { + return is_param(l) && !lvls_to_remove.contains(param_id(l)); + }))); + } else { + new_f = f; + } + + if (!locals_to_remove.empty()) { + unsigned j = 0; + for (unsigned i = 0; i < locals.size(); i++) { + expr const & l = locals[i]; + if (!locals_to_remove.contains(mlocal_name(get_explicit_arg(l)))) { + locals[j] = l; + j++; + } + } + locals.shrink(j); + } + + if (locals.empty()) { + return get_explicit_arg(new_f); + } else { + return mk_implicit(mk_app(new_f, locals)); + } +} + expr Fun(buffer const & locals, expr const & e, parser & p) { bool use_cache = false; return p.rec_save_pos(Fun(locals, e, use_cache), p.pos_of(e)); diff --git a/src/frontends/lean/util.h b/src/frontends/lean/util.h index 096c9dfdf..6fd0f6bd8 100644 --- a/src/frontends/lean/util.h +++ b/src/frontends/lean/util.h @@ -36,6 +36,21 @@ list locals_to_context(expr const & e, parser const & p); That is, when the user writes \c n inside the section she is really getting the term returned by this function. */ expr mk_section_local_ref(name const & n, levels const & sec_ls, unsigned num_sec_params, expr const * sec_params); +/** \brief Return true iff \c e is a term of the form + (@^-1 (@n.{ls} @l_1 ... @l_n)) where + \c n is a constant and l_i's are local constants. + + \remark is_section_local_ref(mk_section_local_ref(n, ls, num_ps, ps)) always hold. +*/ +bool is_section_local_ref(expr const & e); +/** \brief Given a term \c e s.t. is_section_local_ref(e) is true, remove all local constants in \c to_remove. + That is, if \c e is of the form + (@^-1 (@n.{u_1 ... u_k} @l_1 ... @l_n)) + Then, return a term s.t. + 1) any l_i s.t. mlocal_name(l_i) in \c locals_to_remove is removed. + 2) any level u_j in \c lvls_to_remove is removed +*/ +expr update_section_local_ref(expr const & e, name_set const & lvls_to_remove, name_set const & locals_to_remove); /** \brief Fun(locals, e), but also propagate \c e position to result */ expr Fun(buffer const & locals, expr const & e, parser & p); diff --git a/tests/lean/run/section4.lean b/tests/lean/run/section4.lean new file mode 100644 index 000000000..b5676a8b1 --- /dev/null +++ b/tests/lean/run/section4.lean @@ -0,0 +1,20 @@ +import logic + +section + universe k + parameter A : Type + + section + universe l + universe u + parameter B : Type + definition foo (a : A) (b : B) := b + + inductive mypair := + mk : A → B → mypair + end + variable a : A + check foo num a 0 + definition pr1 (p : mypair num) : A := mypair.rec (λ a b, a) p + definition pr2 (p : mypair num) : num := mypair.rec (λ a b, b) p +end