refactor(frontends/lean): add local_decls template that is cheap to copy

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-14 09:56:05 -07:00
parent 5fee6fd140
commit 6b99a29c2c
4 changed files with 78 additions and 39 deletions

View file

@ -80,7 +80,7 @@ static void update_parameters(buffer<name> & ls_buffer, name_set const & found,
ls_buffer.push_back(n); ls_buffer.push_back(n);
}); });
std::sort(ls_buffer.begin(), ls_buffer.end(), [&](name const & n1, name const & n2) { std::sort(ls_buffer.begin(), ls_buffer.end(), [&](name const & n1, name const & n2) {
return *p.get_local_level_index(n1) < *p.get_local_level_index(n2); return p.get_local_level_index(n1) < p.get_local_level_index(n2);
}); });
} }
@ -149,7 +149,7 @@ static void collect_section_locals(expr const & type, expr const & value, parser
section_ps.push_back(*p.get_local(n)); section_ps.push_back(*p.get_local(n));
}); });
std::sort(section_ps.begin(), section_ps.end(), [&](parameter const & p1, parameter const & p2) { std::sort(section_ps.begin(), section_ps.end(), [&](parameter const & p1, parameter const & p2) {
return *p.get_local_index(mlocal_name(p1.m_local)) < *p.get_local_index(mlocal_name(p2.m_local)); return p.get_local_index(mlocal_name(p1.m_local)) < p.get_local_index(mlocal_name(p2.m_local));
}); });
} }

View file

@ -0,0 +1,47 @@
/*
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include <utility>
#include "util/rb_map.h"
#include "util/name.h"
#include "util/list.h"
#include "util/pair.h"
namespace lean {
/**
\brief A "scoped" map from name to values.
It also supports the operation \c find_idx that returns "when a declaration was inserted into the map".
*/
template<typename V>
class local_decls {
typedef rb_map<name, std::pair<V, unsigned>, name_quick_cmp> map;
typedef list<std::pair<map, unsigned>> scopes;
map m_map;
unsigned m_counter;
scopes m_scopes;
public:
local_decls():m_counter(1) {}
local_decls(local_decls const & d):m_map(d.m_map), m_counter(d.m_counter), m_scopes(d.m_scopes) {}
void insert(name const & k, V const & v) { m_map.insert(k, mk_pair(v, m_counter)); m_counter++; }
V const * find(name const & k) const { auto it = m_map.find(k); return it ? &(it->first) : nullptr; }
unsigned find_idx(name const & k) const { auto it = m_map.find(k); return it ? it->second : 0; }
bool contains(name const & k) const { return m_map.contains(k); }
void push() { m_scopes = scopes(mk_pair(m_map, m_counter), m_scopes); }
void pop() {
lean_assert(!is_nil(m_scopes));
m_map = head(m_scopes).first;
m_counter = head(m_scopes).second;
m_scopes = tail(m_scopes);
}
struct mk_scope {
local_decls & m_d;
mk_scope(local_decls & d):m_d(d) { m_d.push(); }
~mk_scope() { m_d.pop(); }
};
};
}

View file

@ -40,10 +40,12 @@ bool get_parser_show_errors(options const & opts) { return opts.get_bool(g_
parser::parser(environment const & env, io_state const & ios, parser::parser(environment const & env, io_state const & ios,
std::istream & strm, char const * strm_name, std::istream & strm, char const * strm_name,
script_state * ss, bool use_exceptions): script_state * ss, bool use_exceptions,
local_level_decls const & lds, local_expr_decls const & eds):
m_env(env), m_ios(ios), m_ss(ss), m_env(env), m_ios(ios), m_ss(ss),
m_verbose(true), m_use_exceptions(use_exceptions), m_verbose(true), m_use_exceptions(use_exceptions),
m_scanner(strm, strm_name), m_pos_table(std::make_shared<pos_info_table>()) { m_scanner(strm, strm_name), m_local_level_decls(lds), m_local_decls(eds),
m_pos_table(std::make_shared<pos_info_table>()) {
m_type_use_placeholder = true; m_type_use_placeholder = true;
m_found_errors = false; m_found_errors = false;
updt_options(); updt_options();
@ -239,11 +241,11 @@ void parser::add_local_level(name const & n, level const & l) {
throw exception(sstream() << "invalid universe declaration, '" << n << "' shadows a global universe"); throw exception(sstream() << "invalid universe declaration, '" << n << "' shadows a global universe");
if (m_local_level_decls.contains(n)) if (m_local_level_decls.contains(n))
throw exception(sstream() << "invalid universe declaration, '" << n << "' shadows a local universe"); throw exception(sstream() << "invalid universe declaration, '" << n << "' shadows a local universe");
m_local_level_decls.insert(n, local_level_entry(l, m_local_level_decls.size())); m_local_level_decls.insert(n, l);
} }
void parser::add_local_expr(name const & n, expr const & e, binder_info const & bi) { void parser::add_local_expr(name const & n, expr const & e, binder_info const & bi) {
m_local_decls.insert(n, local_entry(parameter(pos(), e, bi), m_local_decls.size())); m_local_decls.insert(n, parameter(pos(), e, bi));
} }
void parser::add_local(expr const & e) { void parser::add_local(expr const & e) {
@ -251,26 +253,17 @@ void parser::add_local(expr const & e) {
add_local_expr(local_pp_name(e), e); add_local_expr(local_pp_name(e), e);
} }
optional<unsigned> parser::get_local_level_index(name const & n) const { unsigned parser::get_local_level_index(name const & n) const {
auto it = m_local_level_decls.find(n); return m_local_level_decls.find_idx(n);
if (it != m_local_level_decls.end())
return optional<unsigned>(it->second.second);
else
return optional<unsigned>();
} }
optional<unsigned> parser::get_local_index(name const & n) const { unsigned parser::get_local_index(name const & n) const {
auto it = m_local_decls.find(n); return m_local_decls.find_idx(n);
if (it != m_local_decls.end())
return optional<unsigned>(it->second.second);
else
return optional<unsigned>();
} }
optional<parameter> parser::get_local(name const & n) const { optional<parameter> parser::get_local(name const & n) const {
auto it = m_local_decls.find(n); if (auto it = m_local_decls.find(n))
if (it != m_local_decls.end()) return optional<parameter>(*it);
return optional<parameter>(it->second.first);
else else
return optional<parameter>(); return optional<parameter>();
} }
@ -355,9 +348,8 @@ level parser::parse_level_id() {
auto p = pos(); auto p = pos();
name id = get_name_val(); name id = get_name_val();
next(); next();
auto it = m_local_level_decls.find(id); if (auto it = m_local_level_decls.find(id))
if (it != m_local_level_decls.end()) return *it;
return it->second.first;
if (m_env.is_universe(id)) if (m_env.is_universe(id))
return mk_global_univ(id); return mk_global_univ(id);
if (auto it = get_alias_level(m_env, id)) if (auto it = get_alias_level(m_env, id))
@ -522,7 +514,7 @@ void parser::parse_binders_core(buffer<parameter> & r) {
} }
void parser::parse_binders(buffer<parameter> & r) { void parser::parse_binders(buffer<parameter> & r) {
local_decls::mk_scope scope(m_local_decls); local_expr_decls::mk_scope scope(m_local_decls);
unsigned old_sz = r.size(); unsigned old_sz = r.size();
parse_binders_core(r); parse_binders_core(r);
if (old_sz == r.size()) if (old_sz == r.size())
@ -639,7 +631,6 @@ expr parser::parse_id() {
auto p = pos(); auto p = pos();
name id = get_name_val(); name id = get_name_val();
next(); next();
auto it1 = m_local_decls.find(id);
buffer<level> lvl_buffer; buffer<level> lvl_buffer;
levels ls; levels ls;
if (curr_is_token(g_llevel_curly)) { if (curr_is_token(g_llevel_curly)) {
@ -651,8 +642,8 @@ expr parser::parse_id() {
ls = to_list(lvl_buffer.begin(), lvl_buffer.end()); ls = to_list(lvl_buffer.begin(), lvl_buffer.end());
} }
// locals // locals
if (it1 != m_local_decls.end()) if (auto it1 = m_local_decls.find(id))
return copy_with_new_pos(propagate_levels(it1->second.first.m_local, ls), p); return copy_with_new_pos(propagate_levels(it1->m_local, ls), p);
optional<expr> r; optional<expr> r;
// globals // globals
if (m_env.find(id)) if (m_env.find(id))
@ -736,7 +727,7 @@ expr parser::parse_scoped_expr(unsigned num_params, parameter const * ps, unsign
if (num_params == 0) { if (num_params == 0) {
return parse_expr(rbp); return parse_expr(rbp);
} else { } else {
local_decls::mk_scope scope(m_local_decls); local_expr_decls::mk_scope scope(m_local_decls);
for (unsigned i = 0; i < num_params; i++) for (unsigned i = 0; i < num_params; i++)
add_local(ps[i].m_local); add_local(ps[i].m_local);
return parse_expr(rbp); return parse_expr(rbp);

View file

@ -8,7 +8,6 @@ Author: Leonardo de Moura
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "util/scoped_map.h"
#include "util/script_state.h" #include "util/script_state.h"
#include "util/name_map.h" #include "util/name_map.h"
#include "util/exception.h" #include "util/exception.h"
@ -18,6 +17,7 @@ Author: Leonardo de Moura
#include "library/io_state_stream.h" #include "library/io_state_stream.h"
#include "library/kernel_bindings.h" #include "library/kernel_bindings.h"
#include "frontends/lean/scanner.h" #include "frontends/lean/scanner.h"
#include "frontends/lean/local_decls.h"
#include "frontends/lean/parser_config.h" #include "frontends/lean/parser_config.h"
#include "frontends/lean/parser_pos_provider.h" #include "frontends/lean/parser_pos_provider.h"
@ -41,13 +41,10 @@ struct parser_error : public exception {
}; };
struct interrupt_parser {}; struct interrupt_parser {};
typedef local_decls<parameter> local_expr_decls;
typedef local_decls<level> local_level_decls;
class parser { class parser {
typedef std::pair<parameter, unsigned> local_entry;
typedef std::pair<level, unsigned> local_level_entry;
typedef scoped_map<name, local_entry, name_hash, name_eq> local_decls;
typedef scoped_map<name, local_level_entry, name_hash, name_eq> local_level_decls;
environment m_env; environment m_env;
io_state m_ios; io_state m_ios;
script_state * m_ss; script_state * m_ss;
@ -57,8 +54,8 @@ class parser {
scanner m_scanner; scanner m_scanner;
scanner::token_kind m_curr; scanner::token_kind m_curr;
local_decls m_local_decls;
local_level_decls m_local_level_decls; local_level_decls m_local_level_decls;
local_expr_decls m_local_decls;
pos_info m_last_cmd_pos; pos_info m_last_cmd_pos;
pos_info m_last_script_pos; pos_info m_last_script_pos;
unsigned m_next_tag_idx; unsigned m_next_tag_idx;
@ -128,11 +125,15 @@ class parser {
public: public:
parser(environment const & env, io_state const & ios, parser(environment const & env, io_state const & ios,
std::istream & strm, char const * str_name, std::istream & strm, char const * str_name,
script_state * ss = nullptr, bool use_exceptions = false); script_state * ss = nullptr, bool use_exceptions = false,
local_level_decls const & lds = local_level_decls(),
local_expr_decls const & eds = local_expr_decls());
environment const & env() const { return m_env; } environment const & env() const { return m_env; }
io_state const & ios() const { return m_ios; } io_state const & ios() const { return m_ios; }
script_state * ss() const { return m_ss; } script_state * ss() const { return m_ss; }
local_level_decls const & get_local_level_decls() const { return m_local_level_decls; }
local_expr_decls const & get_local_expr_decls() const { return m_local_decls; }
/** \brief Return the current position information */ /** \brief Return the current position information */
pos_info pos() const { return pos_info(m_scanner.get_line(), m_scanner.get_pos()); } pos_info pos() const { return pos_info(m_scanner.get_line(), m_scanner.get_pos()); }
@ -195,9 +196,9 @@ public:
void add_local_expr(name const & n, expr const & e, binder_info const & bi = binder_info()); void add_local_expr(name const & n, expr const & e, binder_info const & bi = binder_info());
void add_local(expr const & t); void add_local(expr const & t);
/** \brief Position of the local level declaration named \c n in the sequence of local level decls. */ /** \brief Position of the local level declaration named \c n in the sequence of local level decls. */
optional<unsigned> get_local_level_index(name const & n) const; unsigned get_local_level_index(name const & n) const;
/** \brief Position of the local declaration named \c n in the sequence of local decls. */ /** \brief Position of the local declaration named \c n in the sequence of local decls. */
optional<unsigned> get_local_index(name const & n) const; unsigned get_local_index(name const & n) const;
/** \brief Return the local parameter named \c n */ /** \brief Return the local parameter named \c n */
optional<parameter> get_local(name const & n) const; optional<parameter> get_local(name const & n) const;
/** /**