fix(frontends/lean): bug when using nested sections and parameters

see tests/lean/run/section4.lean
This commit is contained in:
Leonardo de Moura 2014-10-08 22:21:29 -07:00
parent f7bbe09db2
commit 25fd370c51
7 changed files with 154 additions and 18 deletions

View file

@ -116,14 +116,53 @@ environment namespace_cmd(parser & p) {
return push_scope(p.env(), p.ios(), scope_kind::Namespace, n); return push_scope(p.env(), p.ios(), scope_kind::Namespace, n);
} }
static void redeclare_aliases(parser & p,
list<pair<name, level>> old_level_entries,
list<pair<name, expr>> old_entries) {
environment const & env = p.env();
if (!in_section_or_context(env))
return;
list<pair<name, expr>> new_entries = p.get_local_entries();
buffer<pair<name, expr>> to_redeclare;
name_set popped_locals;
while (!is_eqp(old_entries, new_entries)) {
pair<name, expr> entry = head(old_entries);
if (is_section_local_ref(entry.second))
to_redeclare.push_back(entry);
else if (is_local(entry.second))
popped_locals.insert(mlocal_name(entry.second));
old_entries = tail(old_entries);
}
name_set popped_levels;
list<pair<name, level>> new_level_entries = p.get_local_level_entries();
while (!is_eqp(old_level_entries, new_level_entries)) {
level const & l = head(old_level_entries).second;
if (is_param(l))
popped_levels.insert(param_id(l));
old_level_entries = tail(old_level_entries);
}
for (auto const & entry : to_redeclare) {
expr new_ref = update_section_local_ref(entry.second, popped_levels, popped_locals);
if (!is_constant(new_ref))
p.add_local_expr(entry.first, new_ref);
}
}
environment end_scoped_cmd(parser & p) { environment end_scoped_cmd(parser & p) {
list<pair<name, level>> level_entries = p.get_local_level_entries();
list<pair<name, expr>> entries = p.get_local_entries();
if (in_section_or_context(p.env())) if (in_section_or_context(p.env()))
p.pop_local_scope(); p.pop_local_scope();
if (p.curr_is_identifier()) { if (p.curr_is_identifier()) {
name n = p.check_atomic_id_next("invalid end of scope, atomic identifier expected"); name n = p.check_atomic_id_next("invalid end of scope, atomic identifier expected");
return pop_scope(p.env(), n); environment env = pop_scope(p.env(), n);
redeclare_aliases(p, level_entries, entries);
return env;
} else { } else {
return pop_scope(p.env()); environment env = pop_scope(p.env());
redeclare_aliases(p, level_entries, entries);
return env;
} }
} }

View file

@ -20,29 +20,30 @@ namespace lean {
template<typename V> template<typename V>
class local_decls { class local_decls {
typedef name_map<pair<V, unsigned>> map; typedef name_map<pair<V, unsigned>> map;
typedef list<std::tuple<map, unsigned, list<V>>> scopes; typedef list<pair<name, V>> entries;
map m_map; typedef std::tuple<map, unsigned, entries> scope;
unsigned m_counter; typedef list<scope> scopes;
scopes m_scopes; map m_map;
list<V> m_values; unsigned m_counter;
scopes m_scopes;
entries m_entries;
public: public:
local_decls():m_counter(1) {} local_decls():m_counter(1) {}
local_decls(local_decls const & d):m_map(d.m_map), m_counter(d.m_counter), m_scopes(d.m_scopes) {} local_decls(local_decls const & d):
m_map(d.m_map), m_counter(d.m_counter), m_scopes(d.m_scopes), m_entries(d.m_entries) {}
void insert(name const & k, V const & v) { void insert(name const & k, V const & v) {
m_map.insert(k, mk_pair(v, m_counter)); m_map.insert(k, mk_pair(v, m_counter));
m_counter++; m_counter++;
m_values = cons(v, m_values); m_entries = cons(mk_pair(k, v), m_entries);
} }
V const * find(name const & k) const { auto it = m_map.find(k); return it ? &(it->first) : nullptr; } V const * find(name const & k) const { auto it = m_map.find(k); return it ? &(it->first) : nullptr; }
unsigned find_idx(name const & k) const { auto it = m_map.find(k); return it ? it->second : 0; } unsigned find_idx(name const & k) const { auto it = m_map.find(k); return it ? it->second : 0; }
bool contains(name const & k) const { return m_map.contains(k); } bool contains(name const & k) const { return m_map.contains(k); }
list<V> const & get_values() const { return m_values; } entries const & get_entries() const { return m_entries; }
void push() { m_scopes = scopes(std::make_tuple(m_map, m_counter, m_values), m_scopes); } void push() { m_scopes = cons(scope(m_map, m_counter, m_entries), m_scopes); }
void pop() { void pop() {
lean_assert(!is_nil(m_scopes)); lean_assert(!is_nil(m_scopes));
m_map = std::get<0>(head(m_scopes)); std::tie(m_map, m_counter, m_entries) = head(m_scopes);
m_counter = std::get<1>(head(m_scopes));
m_values = std::get<2>(head(m_scopes));
m_scopes = tail(m_scopes); m_scopes = tail(m_scopes);
} }
struct mk_scope { struct mk_scope {

View file

@ -464,7 +464,11 @@ void parser::get_include_variables(buffer<expr> & vars) const {
} }
list<expr> parser::locals_to_context() const { list<expr> parser::locals_to_context() const {
return filter(m_local_decls.get_values(), [](expr const & e) { return is_local(e); }); return map_filter<expr>(m_local_decls.get_entries(),
[](pair<name, expr> const & p, expr & out) {
out = p.second;
return is_local(p.second);
});
} }
static unsigned g_level_add_prec = 10; static unsigned g_level_add_prec = 10;

View file

@ -348,9 +348,11 @@ public:
expr const * get_local(name const & n) const { return m_local_decls.find(n); } expr const * get_local(name const & n) const { return m_local_decls.find(n); }
/** \brief Return local declarations as a list of local constants. */ /** \brief Return local declarations as a list of local constants. */
list<expr> locals_to_context() const; list<expr> locals_to_context() const;
/** \brief Return all local declarations and aliases */
/** list<pair<name, expr>> const & get_local_entries() const { return m_local_decls.get_entries(); }
\brief By default, when the parser finds a unknown identifier, it signs an error. /** \brief Return all local level declarations */
list<pair<name, level>> const & get_local_level_entries() const { return m_local_level_decls.get_entries(); }
/** \brief By default, when the parser finds a unknown identifier, it signs an error.
This scope object temporarily changes this behavior. In any scope where this object This scope object temporarily changes this behavior. In any scope where this object
is declared, the parse creates a constant even when the identifier is unknown. is declared, the parse creates a constant even when the identifier is unknown.
This behavior is useful when we are trying to parse mutually recursive declarations. This behavior is useful when we are trying to parse mutually recursive declarations.

View file

@ -114,6 +114,61 @@ expr mk_section_local_ref(name const & n, levels const & sec_ls, unsigned num_se
return mk_implicit(mk_app(mk_explicit(mk_constant(n, sec_ls)), params)); return mk_implicit(mk_app(mk_explicit(mk_constant(n, sec_ls)), params));
} }
bool is_section_local_ref(expr const & e) {
if (!is_implicit(e))
return false;
expr const & imp_arg = get_implicit_arg(e);
if (!is_app(imp_arg))
return false;
buffer<expr> locals;
expr const & f = get_app_args(imp_arg, locals);
return
is_explicit(f) &&
is_constant(get_explicit_arg(f)) &&
std::all_of(locals.begin(), locals.end(),
[](expr const & l) {
return is_explicit(l) && is_local(get_explicit_arg(l));
});
}
expr update_section_local_ref(expr const & e, name_set const & lvls_to_remove, name_set const & locals_to_remove) {
lean_assert(is_section_local_ref(e));
if (locals_to_remove.empty() && lvls_to_remove.empty())
return e;
buffer<expr> locals;
expr const & f = get_app_args(get_implicit_arg(e), locals);
lean_assert(is_explicit(f));
expr new_f;
if (!lvls_to_remove.empty()) {
expr const & c = get_explicit_arg(f);
lean_assert(is_constant(c));
new_f = mk_explicit(update_constant(c, filter(const_levels(c), [&](level const & l) {
return is_param(l) && !lvls_to_remove.contains(param_id(l));
})));
} else {
new_f = f;
}
if (!locals_to_remove.empty()) {
unsigned j = 0;
for (unsigned i = 0; i < locals.size(); i++) {
expr const & l = locals[i];
if (!locals_to_remove.contains(mlocal_name(get_explicit_arg(l)))) {
locals[j] = l;
j++;
}
}
locals.shrink(j);
}
if (locals.empty()) {
return get_explicit_arg(new_f);
} else {
return mk_implicit(mk_app(new_f, locals));
}
}
expr Fun(buffer<expr> const & locals, expr const & e, parser & p) { expr Fun(buffer<expr> const & locals, expr const & e, parser & p) {
bool use_cache = false; bool use_cache = false;
return p.rec_save_pos(Fun(locals, e, use_cache), p.pos_of(e)); return p.rec_save_pos(Fun(locals, e, use_cache), p.pos_of(e));

View file

@ -36,6 +36,21 @@ list<expr> locals_to_context(expr const & e, parser const & p);
That is, when the user writes \c n inside the section she is really getting the term returned by this function. That is, when the user writes \c n inside the section she is really getting the term returned by this function.
*/ */
expr mk_section_local_ref(name const & n, levels const & sec_ls, unsigned num_sec_params, expr const * sec_params); expr mk_section_local_ref(name const & n, levels const & sec_ls, unsigned num_sec_params, expr const * sec_params);
/** \brief Return true iff \c e is a term of the form
<tt>(@^-1 (@n.{ls} @l_1 ... @l_n))</tt> where
\c n is a constant and l_i's are local constants.
\remark is_section_local_ref(mk_section_local_ref(n, ls, num_ps, ps)) always hold.
*/
bool is_section_local_ref(expr const & e);
/** \brief Given a term \c e s.t. is_section_local_ref(e) is true, remove all local constants in \c to_remove.
That is, if \c e is of the form
<tt>(@^-1 (@n.{u_1 ... u_k} @l_1 ... @l_n))</tt>
Then, return a term s.t.
1) any l_i s.t. mlocal_name(l_i) in \c locals_to_remove is removed.
2) any level u_j in \c lvls_to_remove is removed
*/
expr update_section_local_ref(expr const & e, name_set const & lvls_to_remove, name_set const & locals_to_remove);
/** \brief Fun(locals, e), but also propagate \c e position to result */ /** \brief Fun(locals, e), but also propagate \c e position to result */
expr Fun(buffer<expr> const & locals, expr const & e, parser & p); expr Fun(buffer<expr> const & locals, expr const & e, parser & p);

View file

@ -0,0 +1,20 @@
import logic
section
universe k
parameter A : Type
section
universe l
universe u
parameter B : Type
definition foo (a : A) (b : B) := b
inductive mypair :=
mk : A → B → mypair
end
variable a : A
check foo num a 0
definition pr1 (p : mypair num) : A := mypair.rec (λ a b, a) p
definition pr2 (p : mypair num) : num := mypair.rec (λ a b, b) p
end