diff --git a/src/frontends/lean/builtin_cmds.cpp b/src/frontends/lean/builtin_cmds.cpp index d5d536b47..1fe826994 100644 --- a/src/frontends/lean/builtin_cmds.cpp +++ b/src/frontends/lean/builtin_cmds.cpp @@ -80,7 +80,7 @@ static void update_parameters(buffer & ls_buffer, name_set const & found, ls_buffer.push_back(n); }); std::sort(ls_buffer.begin(), ls_buffer.end(), [&](name const & n1, name const & n2) { - return *p.get_local_level_index(n1) < *p.get_local_level_index(n2); + return p.get_local_level_index(n1) < p.get_local_level_index(n2); }); } @@ -149,7 +149,7 @@ static void collect_section_locals(expr const & type, expr const & value, parser section_ps.push_back(*p.get_local(n)); }); std::sort(section_ps.begin(), section_ps.end(), [&](parameter const & p1, parameter const & p2) { - return *p.get_local_index(mlocal_name(p1.m_local)) < *p.get_local_index(mlocal_name(p2.m_local)); + return p.get_local_index(mlocal_name(p1.m_local)) < p.get_local_index(mlocal_name(p2.m_local)); }); } diff --git a/src/frontends/lean/local_decls.h b/src/frontends/lean/local_decls.h new file mode 100644 index 000000000..470d3b523 --- /dev/null +++ b/src/frontends/lean/local_decls.h @@ -0,0 +1,47 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include +#include "util/rb_map.h" +#include "util/name.h" +#include "util/list.h" +#include "util/pair.h" + +namespace lean { +/** + \brief A "scoped" map from name to values. + + It also supports the operation \c find_idx that returns "when a declaration was inserted into the map". +*/ +template +class local_decls { + typedef rb_map, name_quick_cmp> map; + typedef list> scopes; + map m_map; + unsigned m_counter; + scopes m_scopes; +public: + local_decls():m_counter(1) {} + local_decls(local_decls const & d):m_map(d.m_map), m_counter(d.m_counter), m_scopes(d.m_scopes) {} + void insert(name const & k, V const & v) { m_map.insert(k, mk_pair(v, m_counter)); m_counter++; } + V const * find(name const & k) const { auto it = m_map.find(k); return it ? &(it->first) : nullptr; } + unsigned find_idx(name const & k) const { auto it = m_map.find(k); return it ? it->second : 0; } + bool contains(name const & k) const { return m_map.contains(k); } + void push() { m_scopes = scopes(mk_pair(m_map, m_counter), m_scopes); } + void pop() { + lean_assert(!is_nil(m_scopes)); + m_map = head(m_scopes).first; + m_counter = head(m_scopes).second; + m_scopes = tail(m_scopes); + } + struct mk_scope { + local_decls & m_d; + mk_scope(local_decls & d):m_d(d) { m_d.push(); } + ~mk_scope() { m_d.pop(); } + }; +}; +} diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index 80acab732..beb6a2528 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -40,10 +40,12 @@ bool get_parser_show_errors(options const & opts) { return opts.get_bool(g_ parser::parser(environment const & env, io_state const & ios, std::istream & strm, char const * strm_name, - script_state * ss, bool use_exceptions): + script_state * ss, bool use_exceptions, + local_level_decls const & lds, local_expr_decls const & eds): m_env(env), m_ios(ios), m_ss(ss), m_verbose(true), m_use_exceptions(use_exceptions), - m_scanner(strm, strm_name), m_pos_table(std::make_shared()) { + m_scanner(strm, strm_name), m_local_level_decls(lds), m_local_decls(eds), + m_pos_table(std::make_shared()) { m_type_use_placeholder = true; m_found_errors = false; updt_options(); @@ -239,11 +241,11 @@ void parser::add_local_level(name const & n, level const & l) { throw exception(sstream() << "invalid universe declaration, '" << n << "' shadows a global universe"); if (m_local_level_decls.contains(n)) throw exception(sstream() << "invalid universe declaration, '" << n << "' shadows a local universe"); - m_local_level_decls.insert(n, local_level_entry(l, m_local_level_decls.size())); + m_local_level_decls.insert(n, l); } void parser::add_local_expr(name const & n, expr const & e, binder_info const & bi) { - m_local_decls.insert(n, local_entry(parameter(pos(), e, bi), m_local_decls.size())); + m_local_decls.insert(n, parameter(pos(), e, bi)); } void parser::add_local(expr const & e) { @@ -251,26 +253,17 @@ void parser::add_local(expr const & e) { add_local_expr(local_pp_name(e), e); } -optional parser::get_local_level_index(name const & n) const { - auto it = m_local_level_decls.find(n); - if (it != m_local_level_decls.end()) - return optional(it->second.second); - else - return optional(); +unsigned parser::get_local_level_index(name const & n) const { + return m_local_level_decls.find_idx(n); } -optional parser::get_local_index(name const & n) const { - auto it = m_local_decls.find(n); - if (it != m_local_decls.end()) - return optional(it->second.second); - else - return optional(); +unsigned parser::get_local_index(name const & n) const { + return m_local_decls.find_idx(n); } optional parser::get_local(name const & n) const { - auto it = m_local_decls.find(n); - if (it != m_local_decls.end()) - return optional(it->second.first); + if (auto it = m_local_decls.find(n)) + return optional(*it); else return optional(); } @@ -355,9 +348,8 @@ level parser::parse_level_id() { auto p = pos(); name id = get_name_val(); next(); - auto it = m_local_level_decls.find(id); - if (it != m_local_level_decls.end()) - return it->second.first; + if (auto it = m_local_level_decls.find(id)) + return *it; if (m_env.is_universe(id)) return mk_global_univ(id); if (auto it = get_alias_level(m_env, id)) @@ -522,7 +514,7 @@ void parser::parse_binders_core(buffer & r) { } void parser::parse_binders(buffer & r) { - local_decls::mk_scope scope(m_local_decls); + local_expr_decls::mk_scope scope(m_local_decls); unsigned old_sz = r.size(); parse_binders_core(r); if (old_sz == r.size()) @@ -639,7 +631,6 @@ expr parser::parse_id() { auto p = pos(); name id = get_name_val(); next(); - auto it1 = m_local_decls.find(id); buffer lvl_buffer; levels ls; if (curr_is_token(g_llevel_curly)) { @@ -651,8 +642,8 @@ expr parser::parse_id() { ls = to_list(lvl_buffer.begin(), lvl_buffer.end()); } // locals - if (it1 != m_local_decls.end()) - return copy_with_new_pos(propagate_levels(it1->second.first.m_local, ls), p); + if (auto it1 = m_local_decls.find(id)) + return copy_with_new_pos(propagate_levels(it1->m_local, ls), p); optional r; // globals if (m_env.find(id)) @@ -736,7 +727,7 @@ expr parser::parse_scoped_expr(unsigned num_params, parameter const * ps, unsign if (num_params == 0) { return parse_expr(rbp); } else { - local_decls::mk_scope scope(m_local_decls); + local_expr_decls::mk_scope scope(m_local_decls); for (unsigned i = 0; i < num_params; i++) add_local(ps[i].m_local); return parse_expr(rbp); diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index eaf3f9c9e..d6ac2451f 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -8,7 +8,6 @@ Author: Leonardo de Moura #include #include #include -#include "util/scoped_map.h" #include "util/script_state.h" #include "util/name_map.h" #include "util/exception.h" @@ -18,6 +17,7 @@ Author: Leonardo de Moura #include "library/io_state_stream.h" #include "library/kernel_bindings.h" #include "frontends/lean/scanner.h" +#include "frontends/lean/local_decls.h" #include "frontends/lean/parser_config.h" #include "frontends/lean/parser_pos_provider.h" @@ -41,13 +41,10 @@ struct parser_error : public exception { }; struct interrupt_parser {}; +typedef local_decls local_expr_decls; +typedef local_decls local_level_decls; class parser { - typedef std::pair local_entry; - typedef std::pair local_level_entry; - typedef scoped_map local_decls; - typedef scoped_map local_level_decls; - environment m_env; io_state m_ios; script_state * m_ss; @@ -57,8 +54,8 @@ class parser { scanner m_scanner; scanner::token_kind m_curr; - local_decls m_local_decls; local_level_decls m_local_level_decls; + local_expr_decls m_local_decls; pos_info m_last_cmd_pos; pos_info m_last_script_pos; unsigned m_next_tag_idx; @@ -128,11 +125,15 @@ class parser { public: parser(environment const & env, io_state const & ios, std::istream & strm, char const * str_name, - script_state * ss = nullptr, bool use_exceptions = false); + script_state * ss = nullptr, bool use_exceptions = false, + local_level_decls const & lds = local_level_decls(), + local_expr_decls const & eds = local_expr_decls()); environment const & env() const { return m_env; } io_state const & ios() const { return m_ios; } script_state * ss() const { return m_ss; } + local_level_decls const & get_local_level_decls() const { return m_local_level_decls; } + local_expr_decls const & get_local_expr_decls() const { return m_local_decls; } /** \brief Return the current position information */ pos_info pos() const { return pos_info(m_scanner.get_line(), m_scanner.get_pos()); } @@ -195,9 +196,9 @@ public: void add_local_expr(name const & n, expr const & e, binder_info const & bi = binder_info()); void add_local(expr const & t); /** \brief Position of the local level declaration named \c n in the sequence of local level decls. */ - optional get_local_level_index(name const & n) const; + unsigned get_local_level_index(name const & n) const; /** \brief Position of the local declaration named \c n in the sequence of local decls. */ - optional get_local_index(name const & n) const; + unsigned get_local_index(name const & n) const; /** \brief Return the local parameter named \c n */ optional get_local(name const & n) const; /**