feat(kernel/environment): add environment extension objects, the environment can be extended with frontend specific objects

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-06 19:21:22 -08:00
parent 6cb282d70b
commit 80e23f98c7
3 changed files with 176 additions and 0 deletions

View file

@ -9,6 +9,7 @@ Author: Leonardo de Moura
#include <atomic> #include <atomic>
#include <tuple> #include <tuple>
#include <unordered_map> #include <unordered_map>
#include <mutex>
#include "util/safe_arith.h" #include "util/safe_arith.h"
#include "kernel/for_each.h" #include "kernel/for_each.h"
#include "kernel/kernel_exception.h" #include "kernel/kernel_exception.h"
@ -17,6 +18,35 @@ Author: Leonardo de Moura
#include "kernel/normalizer.h" #include "kernel/normalizer.h"
namespace lean { namespace lean {
class extension_factory {
std::vector<environment::mk_extension> m_makers;
std::mutex m_makers_mutex;
public:
unsigned register_extension(environment::mk_extension mk) {
std::lock_guard<std::mutex> lock(m_makers_mutex);
unsigned r = m_makers.size();
m_makers.push_back(mk);
return r;
}
std::unique_ptr<environment::extension> mk(unsigned extid) {
std::lock_guard<std::mutex> lock(m_makers_mutex);
return m_makers[extid]();
}
};
static std::unique_ptr<extension_factory> g_extension_factory;
static extension_factory & get_extension_factory() {
if (!g_extension_factory)
g_extension_factory.reset(new extension_factory());
return *g_extension_factory;
}
unsigned environment::register_extension(mk_extension mk) {
return get_extension_factory().register_extension(mk);
}
/** \brief Implementation of the Lean environment. */ /** \brief Implementation of the Lean environment. */
struct environment::imp { struct environment::imp {
// Remark: only named objects are stored in the dictionary. // Remark: only named objects are stored in the dictionary.
@ -37,6 +67,23 @@ struct environment::imp {
object_dictionary m_object_dictionary; object_dictionary m_object_dictionary;
type_checker m_type_checker; type_checker m_type_checker;
std::vector<std::unique_ptr<extension>> m_extensions;
friend class extension;
extension & get_extension_core(unsigned extid, environment const & env) {
if (has_children())
throw read_only_environment_exception(env);
if (extid >= m_extensions.size())
m_extensions.resize(extid+1);
if (!m_extensions[extid]) {
std::unique_ptr<extension> ext = get_extension_factory().mk(extid);
ext->m_extid = extid;
ext->m_env = this;
m_extensions[extid].swap(ext);
}
return *(m_extensions[extid].get());
}
unsigned get_max_weight(expr const & e) { unsigned get_max_weight(expr const & e) {
unsigned w = 0; unsigned w = 0;
auto proc = [&](expr const & c, unsigned) { auto proc = [&](expr const & c, unsigned) {
@ -484,4 +531,31 @@ void environment::display(std::ostream & out) const {
void environment::set_interrupt(bool flag) { void environment::set_interrupt(bool flag) {
m_ptr->set_interrupt(flag); m_ptr->set_interrupt(flag);
} }
environment::extension & environment::get_extension_core(unsigned extid) const {
return m_ptr->get_extension_core(extid, *this);
}
environment::extension::extension():
m_env(nullptr),
m_extid(0) {
}
environment::extension::~extension() {
}
environment::extension const * environment::extension::get_parent_core() const {
if (m_env == nullptr)
return nullptr;
imp * parent = m_env->m_parent.get();
while (parent) {
if (m_extid < parent->m_extensions.size()) {
extension * ext = parent->m_extensions[m_extid].get();
if (ext)
return ext;
}
parent = parent->m_parent.get();
}
return nullptr;
}
} }

View file

@ -211,5 +211,60 @@ public:
void set_interrupt(bool flag); void set_interrupt(bool flag);
void interrupt() { set_interrupt(true); } void interrupt() { set_interrupt(true); }
void reset_interrupt() { set_interrupt(false); } void reset_interrupt() { set_interrupt(false); }
/**
\brief Frontend can store data in environment extensions.
Each extension is associated with a unique token/id.
This token allows the frontend to retrieve/store an extension object
in the environment
*/
class extension {
friend class imp;
imp * m_env;
unsigned m_extid; // extension id
extension const * get_parent_core() const;
public:
extension();
virtual ~extension();
/**
\brief Return a constant reference for a parent extension,
and a nullptr if there is no parent/ancestor, or if the
parent/ancestor has an extension.
*/
template<typename Ext> Ext const * get_parent() const {
extension const * ext = get_parent_core();
lean_assert(!ext || dynamic_cast<Ext const *>(ext) != nullptr);
return static_cast<Ext const *>(ext);
}
};
/**
\brief Register an environment extension. Every environment
object will contain this extension. The funciton mk creates a
new instance of the extension. The extension object can be
retrieved using the token (unsigned integer) returned by this
method.
\remark The extension objects are created on demand.
\see get_extension
*/
typedef std::unique_ptr<extension> (*mk_extension)();
static unsigned register_extension(mk_extension mk);
private:
extension & get_extension_core(unsigned extid) const;
public:
/**
\brief Retrieve the extension associated with the token \c extid.
The token is the value returned by \c register_extension.
*/
template<typename Ext>
Ext & get_extension(unsigned extid) const {
extension & ext = get_extension_core(extid);
lean_assert(dynamic_cast<Ext*>(&ext) != nullptr);
return static_cast<Ext&>(ext);
}
}; };
} }

View file

@ -225,6 +225,51 @@ static void tst10() {
lean_assert(env.get_object("d").get_weight() == 3); lean_assert(env.get_object("d").get_weight() == 3);
} }
struct my_extension : public environment::extension {
unsigned m_value1;
unsigned m_value2;
my_extension():m_value1(0), m_value2(0) {}
};
struct my_extension_reg {
unsigned m_extid;
my_extension_reg() {
m_extid = environment::register_extension([](){ return std::unique_ptr<environment::extension>(new my_extension()); });
}
};
static my_extension_reg R;
static void tst11() {
unsigned extid = R.m_extid;
environment env;
my_extension & ext = env.get_extension<my_extension>(extid);
ext.m_value1 = 10;
ext.m_value2 = 20;
my_extension & ext2 = env.get_extension<my_extension>(extid);
lean_assert(ext2.m_value1 == 10);
lean_assert(ext2.m_value2 == 20);
environment child = env.mk_child();
my_extension & ext3 = child.get_extension<my_extension>(extid);
lean_assert(ext3.m_value1 == 0);
lean_assert(ext3.m_value2 == 0);
my_extension const * ext4 = ext3.get_parent<my_extension>();
lean_assert(ext4);
lean_assert(ext4->m_value1 == 10);
lean_assert(ext4->m_value2 == 20);
lean_assert(ext4->get_parent<my_extension>() == nullptr);
}
static void tst12() {
unsigned extid = R.m_extid;
environment env;
environment child = env.mk_child();
my_extension & ext = child.get_extension<my_extension>(extid);
lean_assert(ext.m_value1 == 0);
lean_assert(ext.m_value2 == 0);
lean_assert(ext.get_parent<my_extension>() == nullptr);
}
int main() { int main() {
enable_trace("is_convertible"); enable_trace("is_convertible");
tst1(); tst1();
@ -237,5 +282,7 @@ int main() {
tst8(); tst8();
tst9(); tst9();
tst10(); tst10();
tst11();
tst12();
return has_violations() ? 1 : 0; return has_violations() ? 1 : 0;
} }