From 80e23f98c783b59f256a71c713ac94622a3f0d0f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 6 Nov 2013 19:21:22 -0800 Subject: [PATCH] feat(kernel/environment): add environment extension objects, the environment can be extended with frontend specific objects Signed-off-by: Leonardo de Moura --- src/kernel/environment.cpp | 74 ++++++++++++++++++++++++++++++++ src/kernel/environment.h | 55 ++++++++++++++++++++++++ src/tests/kernel/environment.cpp | 47 ++++++++++++++++++++ 3 files changed, 176 insertions(+) diff --git a/src/kernel/environment.cpp b/src/kernel/environment.cpp index 1aadd6e1b..2bbc5ddb2 100644 --- a/src/kernel/environment.cpp +++ b/src/kernel/environment.cpp @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include #include #include +#include #include "util/safe_arith.h" #include "kernel/for_each.h" #include "kernel/kernel_exception.h" @@ -17,6 +18,35 @@ Author: Leonardo de Moura #include "kernel/normalizer.h" namespace lean { + +class extension_factory { + std::vector m_makers; + std::mutex m_makers_mutex; +public: + unsigned register_extension(environment::mk_extension mk) { + std::lock_guard lock(m_makers_mutex); + unsigned r = m_makers.size(); + m_makers.push_back(mk); + return r; + } + + std::unique_ptr mk(unsigned extid) { + std::lock_guard lock(m_makers_mutex); + return m_makers[extid](); + } +}; + +static std::unique_ptr g_extension_factory; +static extension_factory & get_extension_factory() { + if (!g_extension_factory) + g_extension_factory.reset(new extension_factory()); + return *g_extension_factory; +} + +unsigned environment::register_extension(mk_extension mk) { + return get_extension_factory().register_extension(mk); +} + /** \brief Implementation of the Lean environment. */ struct environment::imp { // Remark: only named objects are stored in the dictionary. @@ -37,6 +67,23 @@ struct environment::imp { object_dictionary m_object_dictionary; type_checker m_type_checker; + std::vector> m_extensions; + friend class extension; + + extension & get_extension_core(unsigned extid, environment const & env) { + if (has_children()) + throw read_only_environment_exception(env); + if (extid >= m_extensions.size()) + m_extensions.resize(extid+1); + if (!m_extensions[extid]) { + std::unique_ptr ext = get_extension_factory().mk(extid); + ext->m_extid = extid; + ext->m_env = this; + m_extensions[extid].swap(ext); + } + return *(m_extensions[extid].get()); + } + unsigned get_max_weight(expr const & e) { unsigned w = 0; auto proc = [&](expr const & c, unsigned) { @@ -484,4 +531,31 @@ void environment::display(std::ostream & out) const { void environment::set_interrupt(bool flag) { m_ptr->set_interrupt(flag); } + +environment::extension & environment::get_extension_core(unsigned extid) const { + return m_ptr->get_extension_core(extid, *this); +} + +environment::extension::extension(): + m_env(nullptr), + m_extid(0) { +} + +environment::extension::~extension() { +} + +environment::extension const * environment::extension::get_parent_core() const { + if (m_env == nullptr) + return nullptr; + imp * parent = m_env->m_parent.get(); + while (parent) { + if (m_extid < parent->m_extensions.size()) { + extension * ext = parent->m_extensions[m_extid].get(); + if (ext) + return ext; + } + parent = parent->m_parent.get(); + } + return nullptr; +} } diff --git a/src/kernel/environment.h b/src/kernel/environment.h index 0e3b11f51..bc419f095 100644 --- a/src/kernel/environment.h +++ b/src/kernel/environment.h @@ -211,5 +211,60 @@ public: void set_interrupt(bool flag); void interrupt() { set_interrupt(true); } void reset_interrupt() { set_interrupt(false); } + + /** + \brief Frontend can store data in environment extensions. + Each extension is associated with a unique token/id. + This token allows the frontend to retrieve/store an extension object + in the environment + */ + class extension { + friend class imp; + imp * m_env; + unsigned m_extid; // extension id + extension const * get_parent_core() const; + public: + extension(); + virtual ~extension(); + /** + \brief Return a constant reference for a parent extension, + and a nullptr if there is no parent/ancestor, or if the + parent/ancestor has an extension. + */ + template Ext const * get_parent() const { + extension const * ext = get_parent_core(); + lean_assert(!ext || dynamic_cast(ext) != nullptr); + return static_cast(ext); + } + }; + + /** + \brief Register an environment extension. Every environment + object will contain this extension. The funciton mk creates a + new instance of the extension. The extension object can be + retrieved using the token (unsigned integer) returned by this + method. + + \remark The extension objects are created on demand. + + \see get_extension + */ + typedef std::unique_ptr (*mk_extension)(); + static unsigned register_extension(mk_extension mk); + +private: + extension & get_extension_core(unsigned extid) const; + +public: + /** + \brief Retrieve the extension associated with the token \c extid. + The token is the value returned by \c register_extension. + */ + template + Ext & get_extension(unsigned extid) const { + extension & ext = get_extension_core(extid); + lean_assert(dynamic_cast(&ext) != nullptr); + return static_cast(ext); + } }; } diff --git a/src/tests/kernel/environment.cpp b/src/tests/kernel/environment.cpp index eb77e782e..153e5f928 100644 --- a/src/tests/kernel/environment.cpp +++ b/src/tests/kernel/environment.cpp @@ -225,6 +225,51 @@ static void tst10() { lean_assert(env.get_object("d").get_weight() == 3); } +struct my_extension : public environment::extension { + unsigned m_value1; + unsigned m_value2; + my_extension():m_value1(0), m_value2(0) {} +}; + +struct my_extension_reg { + unsigned m_extid; + my_extension_reg() { + m_extid = environment::register_extension([](){ return std::unique_ptr(new my_extension()); }); + } +}; + +static my_extension_reg R; + +static void tst11() { + unsigned extid = R.m_extid; + environment env; + my_extension & ext = env.get_extension(extid); + ext.m_value1 = 10; + ext.m_value2 = 20; + my_extension & ext2 = env.get_extension(extid); + lean_assert(ext2.m_value1 == 10); + lean_assert(ext2.m_value2 == 20); + environment child = env.mk_child(); + my_extension & ext3 = child.get_extension(extid); + lean_assert(ext3.m_value1 == 0); + lean_assert(ext3.m_value2 == 0); + my_extension const * ext4 = ext3.get_parent(); + lean_assert(ext4); + lean_assert(ext4->m_value1 == 10); + lean_assert(ext4->m_value2 == 20); + lean_assert(ext4->get_parent() == nullptr); +} + +static void tst12() { + unsigned extid = R.m_extid; + environment env; + environment child = env.mk_child(); + my_extension & ext = child.get_extension(extid); + lean_assert(ext.m_value1 == 0); + lean_assert(ext.m_value2 == 0); + lean_assert(ext.get_parent() == nullptr); +} + int main() { enable_trace("is_convertible"); tst1(); @@ -237,5 +282,7 @@ int main() { tst8(); tst9(); tst10(); + tst11(); + tst12(); return has_violations() ? 1 : 0; }