fix(frontends/lean): bug when using nested sections and parameters
see tests/lean/run/section4.lean
This commit is contained in:
parent
f7bbe09db2
commit
25fd370c51
7 changed files with 154 additions and 18 deletions
|
@ -116,14 +116,53 @@ environment namespace_cmd(parser & p) {
|
|||
return push_scope(p.env(), p.ios(), scope_kind::Namespace, n);
|
||||
}
|
||||
|
||||
static void redeclare_aliases(parser & p,
|
||||
list<pair<name, level>> old_level_entries,
|
||||
list<pair<name, expr>> old_entries) {
|
||||
environment const & env = p.env();
|
||||
if (!in_section_or_context(env))
|
||||
return;
|
||||
list<pair<name, expr>> new_entries = p.get_local_entries();
|
||||
buffer<pair<name, expr>> to_redeclare;
|
||||
name_set popped_locals;
|
||||
while (!is_eqp(old_entries, new_entries)) {
|
||||
pair<name, expr> entry = head(old_entries);
|
||||
if (is_section_local_ref(entry.second))
|
||||
to_redeclare.push_back(entry);
|
||||
else if (is_local(entry.second))
|
||||
popped_locals.insert(mlocal_name(entry.second));
|
||||
old_entries = tail(old_entries);
|
||||
}
|
||||
name_set popped_levels;
|
||||
list<pair<name, level>> new_level_entries = p.get_local_level_entries();
|
||||
while (!is_eqp(old_level_entries, new_level_entries)) {
|
||||
level const & l = head(old_level_entries).second;
|
||||
if (is_param(l))
|
||||
popped_levels.insert(param_id(l));
|
||||
old_level_entries = tail(old_level_entries);
|
||||
}
|
||||
|
||||
for (auto const & entry : to_redeclare) {
|
||||
expr new_ref = update_section_local_ref(entry.second, popped_levels, popped_locals);
|
||||
if (!is_constant(new_ref))
|
||||
p.add_local_expr(entry.first, new_ref);
|
||||
}
|
||||
}
|
||||
|
||||
environment end_scoped_cmd(parser & p) {
|
||||
list<pair<name, level>> level_entries = p.get_local_level_entries();
|
||||
list<pair<name, expr>> entries = p.get_local_entries();
|
||||
if (in_section_or_context(p.env()))
|
||||
p.pop_local_scope();
|
||||
if (p.curr_is_identifier()) {
|
||||
name n = p.check_atomic_id_next("invalid end of scope, atomic identifier expected");
|
||||
return pop_scope(p.env(), n);
|
||||
environment env = pop_scope(p.env(), n);
|
||||
redeclare_aliases(p, level_entries, entries);
|
||||
return env;
|
||||
} else {
|
||||
return pop_scope(p.env());
|
||||
environment env = pop_scope(p.env());
|
||||
redeclare_aliases(p, level_entries, entries);
|
||||
return env;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,29 +20,30 @@ namespace lean {
|
|||
template<typename V>
|
||||
class local_decls {
|
||||
typedef name_map<pair<V, unsigned>> map;
|
||||
typedef list<std::tuple<map, unsigned, list<V>>> scopes;
|
||||
map m_map;
|
||||
unsigned m_counter;
|
||||
scopes m_scopes;
|
||||
list<V> m_values;
|
||||
typedef list<pair<name, V>> entries;
|
||||
typedef std::tuple<map, unsigned, entries> scope;
|
||||
typedef list<scope> scopes;
|
||||
map m_map;
|
||||
unsigned m_counter;
|
||||
scopes m_scopes;
|
||||
entries m_entries;
|
||||
public:
|
||||
local_decls():m_counter(1) {}
|
||||
local_decls(local_decls const & d):m_map(d.m_map), m_counter(d.m_counter), m_scopes(d.m_scopes) {}
|
||||
local_decls(local_decls const & d):
|
||||
m_map(d.m_map), m_counter(d.m_counter), m_scopes(d.m_scopes), m_entries(d.m_entries) {}
|
||||
void insert(name const & k, V const & v) {
|
||||
m_map.insert(k, mk_pair(v, m_counter));
|
||||
m_counter++;
|
||||
m_values = cons(v, m_values);
|
||||
m_entries = cons(mk_pair(k, v), m_entries);
|
||||
}
|
||||
V const * find(name const & k) const { auto it = m_map.find(k); return it ? &(it->first) : nullptr; }
|
||||
unsigned find_idx(name const & k) const { auto it = m_map.find(k); return it ? it->second : 0; }
|
||||
bool contains(name const & k) const { return m_map.contains(k); }
|
||||
list<V> const & get_values() const { return m_values; }
|
||||
void push() { m_scopes = scopes(std::make_tuple(m_map, m_counter, m_values), m_scopes); }
|
||||
entries const & get_entries() const { return m_entries; }
|
||||
void push() { m_scopes = cons(scope(m_map, m_counter, m_entries), m_scopes); }
|
||||
void pop() {
|
||||
lean_assert(!is_nil(m_scopes));
|
||||
m_map = std::get<0>(head(m_scopes));
|
||||
m_counter = std::get<1>(head(m_scopes));
|
||||
m_values = std::get<2>(head(m_scopes));
|
||||
std::tie(m_map, m_counter, m_entries) = head(m_scopes);
|
||||
m_scopes = tail(m_scopes);
|
||||
}
|
||||
struct mk_scope {
|
||||
|
|
|
@ -464,7 +464,11 @@ void parser::get_include_variables(buffer<expr> & vars) const {
|
|||
}
|
||||
|
||||
list<expr> parser::locals_to_context() const {
|
||||
return filter(m_local_decls.get_values(), [](expr const & e) { return is_local(e); });
|
||||
return map_filter<expr>(m_local_decls.get_entries(),
|
||||
[](pair<name, expr> const & p, expr & out) {
|
||||
out = p.second;
|
||||
return is_local(p.second);
|
||||
});
|
||||
}
|
||||
|
||||
static unsigned g_level_add_prec = 10;
|
||||
|
|
|
@ -348,9 +348,11 @@ public:
|
|||
expr const * get_local(name const & n) const { return m_local_decls.find(n); }
|
||||
/** \brief Return local declarations as a list of local constants. */
|
||||
list<expr> locals_to_context() const;
|
||||
|
||||
/**
|
||||
\brief By default, when the parser finds a unknown identifier, it signs an error.
|
||||
/** \brief Return all local declarations and aliases */
|
||||
list<pair<name, expr>> const & get_local_entries() const { return m_local_decls.get_entries(); }
|
||||
/** \brief Return all local level declarations */
|
||||
list<pair<name, level>> const & get_local_level_entries() const { return m_local_level_decls.get_entries(); }
|
||||
/** \brief By default, when the parser finds a unknown identifier, it signs an error.
|
||||
This scope object temporarily changes this behavior. In any scope where this object
|
||||
is declared, the parse creates a constant even when the identifier is unknown.
|
||||
This behavior is useful when we are trying to parse mutually recursive declarations.
|
||||
|
|
|
@ -114,6 +114,61 @@ expr mk_section_local_ref(name const & n, levels const & sec_ls, unsigned num_se
|
|||
return mk_implicit(mk_app(mk_explicit(mk_constant(n, sec_ls)), params));
|
||||
}
|
||||
|
||||
bool is_section_local_ref(expr const & e) {
|
||||
if (!is_implicit(e))
|
||||
return false;
|
||||
expr const & imp_arg = get_implicit_arg(e);
|
||||
if (!is_app(imp_arg))
|
||||
return false;
|
||||
buffer<expr> locals;
|
||||
expr const & f = get_app_args(imp_arg, locals);
|
||||
return
|
||||
is_explicit(f) &&
|
||||
is_constant(get_explicit_arg(f)) &&
|
||||
std::all_of(locals.begin(), locals.end(),
|
||||
[](expr const & l) {
|
||||
return is_explicit(l) && is_local(get_explicit_arg(l));
|
||||
});
|
||||
}
|
||||
|
||||
expr update_section_local_ref(expr const & e, name_set const & lvls_to_remove, name_set const & locals_to_remove) {
|
||||
lean_assert(is_section_local_ref(e));
|
||||
if (locals_to_remove.empty() && lvls_to_remove.empty())
|
||||
return e;
|
||||
buffer<expr> locals;
|
||||
expr const & f = get_app_args(get_implicit_arg(e), locals);
|
||||
lean_assert(is_explicit(f));
|
||||
|
||||
expr new_f;
|
||||
if (!lvls_to_remove.empty()) {
|
||||
expr const & c = get_explicit_arg(f);
|
||||
lean_assert(is_constant(c));
|
||||
new_f = mk_explicit(update_constant(c, filter(const_levels(c), [&](level const & l) {
|
||||
return is_param(l) && !lvls_to_remove.contains(param_id(l));
|
||||
})));
|
||||
} else {
|
||||
new_f = f;
|
||||
}
|
||||
|
||||
if (!locals_to_remove.empty()) {
|
||||
unsigned j = 0;
|
||||
for (unsigned i = 0; i < locals.size(); i++) {
|
||||
expr const & l = locals[i];
|
||||
if (!locals_to_remove.contains(mlocal_name(get_explicit_arg(l)))) {
|
||||
locals[j] = l;
|
||||
j++;
|
||||
}
|
||||
}
|
||||
locals.shrink(j);
|
||||
}
|
||||
|
||||
if (locals.empty()) {
|
||||
return get_explicit_arg(new_f);
|
||||
} else {
|
||||
return mk_implicit(mk_app(new_f, locals));
|
||||
}
|
||||
}
|
||||
|
||||
expr Fun(buffer<expr> const & locals, expr const & e, parser & p) {
|
||||
bool use_cache = false;
|
||||
return p.rec_save_pos(Fun(locals, e, use_cache), p.pos_of(e));
|
||||
|
|
|
@ -36,6 +36,21 @@ list<expr> locals_to_context(expr const & e, parser const & p);
|
|||
That is, when the user writes \c n inside the section she is really getting the term returned by this function.
|
||||
*/
|
||||
expr mk_section_local_ref(name const & n, levels const & sec_ls, unsigned num_sec_params, expr const * sec_params);
|
||||
/** \brief Return true iff \c e is a term of the form
|
||||
<tt>(@^-1 (@n.{ls} @l_1 ... @l_n))</tt> where
|
||||
\c n is a constant and l_i's are local constants.
|
||||
|
||||
\remark is_section_local_ref(mk_section_local_ref(n, ls, num_ps, ps)) always hold.
|
||||
*/
|
||||
bool is_section_local_ref(expr const & e);
|
||||
/** \brief Given a term \c e s.t. is_section_local_ref(e) is true, remove all local constants in \c to_remove.
|
||||
That is, if \c e is of the form
|
||||
<tt>(@^-1 (@n.{u_1 ... u_k} @l_1 ... @l_n))</tt>
|
||||
Then, return a term s.t.
|
||||
1) any l_i s.t. mlocal_name(l_i) in \c locals_to_remove is removed.
|
||||
2) any level u_j in \c lvls_to_remove is removed
|
||||
*/
|
||||
expr update_section_local_ref(expr const & e, name_set const & lvls_to_remove, name_set const & locals_to_remove);
|
||||
|
||||
/** \brief Fun(locals, e), but also propagate \c e position to result */
|
||||
expr Fun(buffer<expr> const & locals, expr const & e, parser & p);
|
||||
|
|
20
tests/lean/run/section4.lean
Normal file
20
tests/lean/run/section4.lean
Normal file
|
@ -0,0 +1,20 @@
|
|||
import logic
|
||||
|
||||
section
|
||||
universe k
|
||||
parameter A : Type
|
||||
|
||||
section
|
||||
universe l
|
||||
universe u
|
||||
parameter B : Type
|
||||
definition foo (a : A) (b : B) := b
|
||||
|
||||
inductive mypair :=
|
||||
mk : A → B → mypair
|
||||
end
|
||||
variable a : A
|
||||
check foo num a 0
|
||||
definition pr1 (p : mypair num) : A := mypair.rec (λ a b, a) p
|
||||
definition pr2 (p : mypair num) : num := mypair.rec (λ a b, b) p
|
||||
end
|
Loading…
Reference in a new issue