diff --git a/src/library/local_context.cpp b/src/library/local_context.cpp index b24a38e26..84d7b8e64 100644 --- a/src/library/local_context.cpp +++ b/src/library/local_context.cpp @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include #include "util/fresh_name.h" #include "library/local_context.h" @@ -18,12 +19,12 @@ void local_decl::cell::dealloc() { this->~cell(); get_local_decl_allocator().recycle(this); } -local_decl::cell::cell(name const & n, name const & pp_n, expr const & t, optional const & v, binder_info const & bi): - m_name(n), m_pp_name(pp_n), m_type(t), m_value(v), m_bi(bi), m_rc(1) {} +local_decl::cell::cell(unsigned idx, name const & n, name const & pp_n, expr const & t, optional const & v, binder_info const & bi): + m_name(n), m_pp_name(pp_n), m_type(t), m_value(v), m_bi(bi), m_idx(idx), m_rc(1) {} local_decl::local_decl():local_decl(*g_dummy_decl) {} -local_decl::local_decl(name const & n, name const & pp_n, expr const & t, optional const & v, binder_info const & bi) { - m_ptr = new (get_local_decl_allocator().allocate()) cell(n, pp_n, t, v, bi); +local_decl::local_decl(unsigned idx, name const & n, name const & pp_n, expr const & t, optional const & v, binder_info const & bi) { + m_ptr = new (get_local_decl_allocator().allocate()) cell(idx, n, pp_n, t, v, bi); } name mk_local_decl_name() { @@ -40,9 +41,11 @@ bool is_local_decl_ref(expr const & e) { expr local_context::mk_local_decl(name const & n, name const & ppn, expr const & type, optional const & value, binder_info const & bi) { lean_assert(!m_local_decl_map.contains(n)); - local_decl l(n, ppn, type, value, bi); - m_local_decls = cons(l, m_local_decls); - m_local_decl_map.insert(n, l); + unsigned idx = m_next_idx; + m_next_idx++; + local_decl l(idx, n, ppn, type, value, bi); + m_name2local_decl.insert(n, l); + m_idx2local_decl.insert(idx, l); return mk_local_ref(n); } @@ -66,24 +69,29 @@ expr local_context::mk_local_decl(name const & ppn, expr const & type, expr cons optional local_context::get_local_decl(expr const & e) { lean_assert(is_local_decl_ref(e)); - if (auto r = m_local_decl_map.find(mlocal_name(e))) + if (auto r = m_name2local_decl.find(mlocal_name(e))) return optional(*r); else return optional(); } void local_context::for_each(std::function const & fn) const { - m_local_decl_map.for_each([&](name const &, local_decl const & d) { fn(d); }); + m_idx2local_decl.for_each([&](unsigned, local_decl const & d) { fn(d); }); } optional local_context::find_if(std::function const & pred) const { // NOLINT - return m_local_decl_map.find_if([&](name const &, local_decl const & d) { return pred(d); }); + return m_idx2local_decl.find_if([&](unsigned, local_decl const & d) { return pred(d); }); +} + +void local_context::for_each_after(local_decl const & d, std::function const & fn) const { + m_idx2local_decl.for_each_greater(d.get_idx(), [&](unsigned, local_decl const & d) { return fn(d); }); } void initialize_local_context() { g_local_prefix = new name(name::mk_internal_unique_name()); g_dummy_type = new expr(mk_constant(name::mk_internal_unique_name())); - g_dummy_decl = new local_decl(name("__local_decl_for_default_constructor"), name("__local_decl_for_default_constructor"), + g_dummy_decl = new local_decl(std::numeric_limits::max(), + name("__local_decl_for_default_constructor"), name("__local_decl_for_default_constructor"), *g_dummy_type, optional(), binder_info()); } diff --git a/src/library/local_context.h b/src/library/local_context.h index 9ca452867..442a9a8e7 100644 --- a/src/library/local_context.h +++ b/src/library/local_context.h @@ -23,16 +23,19 @@ public: expr m_type; optional m_value; binder_info m_bi; + unsigned m_idx; MK_LEAN_RC(); // Declare m_rc counter void dealloc(); - cell(name const & n, name const & pp_n, expr const & t, optional const & v, binder_info const & bi); + cell(unsigned idx, name const & n, name const & pp_n, expr const & t, optional const & v, binder_info const & bi); }; private: cell * m_ptr; friend class local_context; + friend void initialize_local_context(); + local_decl(unsigned idx, name const & n, name const & pp_n, expr const & t, optional const & v, binder_info const & bi); + unsigned get_idx() const { return m_ptr->m_idx; } public: local_decl(); - local_decl(name const & n, name const & pp_n, expr const & t, optional const & v, binder_info const & bi); local_decl(local_decl const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } local_decl(local_decl && s):m_ptr(s.m_ptr) { s.m_ptr = nullptr; } ~local_decl() { if (m_ptr) m_ptr->dec_ref(); } @@ -51,17 +54,25 @@ public: bool is_local_decl_ref(expr const & e); class local_context { - name_map m_local_decl_map; - list m_local_decls; + typedef rb_map idx2local_decl; + unsigned m_next_idx; + name_map m_name2local_decl; + idx2local_decl m_idx2local_decl; expr mk_local_decl(name const & n, name const & ppn, expr const & type, optional const & value, binder_info const & bi); public: + local_context():m_next_idx(0) {} expr mk_local_decl(expr const & type, binder_info const & bi = binder_info()); expr mk_local_decl(expr const & type, expr const & value); expr mk_local_decl(name const & ppn, expr const & type, binder_info const & bi = binder_info()); expr mk_local_decl(name const & ppn, expr const & type, expr const & value); + /** \brief Return the local declarations for the given reference. + \pre is_local_decl_ref(e) */ optional get_local_decl(expr const & e); + /** \brief Traverse local declarations based on the order they were created */ void for_each(std::function const & fn) const; optional find_if(std::function const & pred) const; // NOLINT + /** \brief Execute fn for each local declaration created after \c d. */ + void for_each_after(local_decl const & d, std::function const & fn) const; }; void initialize_local_context();