diff --git a/src/kernel/CMakeLists.txt b/src/kernel/CMakeLists.txt index 44b99576e..334ec97f0 100644 --- a/src/kernel/CMakeLists.txt +++ b/src/kernel/CMakeLists.txt @@ -1,9 +1,8 @@ add_library(kernel level.cpp diff_cnstrs.cpp expr.cpp expr_eq_fn.cpp for_each_fn.cpp replace_fn.cpp free_vars.cpp abstract.cpp instantiate.cpp context.cpp formatter.cpp max_sharing.cpp -definition.cpp replace_visitor.cpp environment.cpp -justification.cpp pos_info_provider.cpp metavar.cpp converter.cpp -constraint.cpp type_checker.cpp error_msgs.cpp kernel_exception.cpp -) +definition.cpp replace_visitor.cpp environment.cpp justification.cpp +pos_info_provider.cpp metavar.cpp converter.cpp constraint.cpp +type_checker.cpp error_msgs.cpp kernel_exception.cpp ) target_link_libraries(kernel ${LEAN_LIBS}) diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index b42a4a3ac..3054f6b96 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -181,20 +181,16 @@ void expr_let::dealloc(buffer & todelete) { expr_let::~expr_let() {} // Macro definition -int macro_definition::push_lua(lua_State *) const { return 0; } // NOLINT -bool macro_definition::operator==(macro_definition const & other) const { return typeid(*this) == typeid(other); } -bool macro_definition::operator<(macro_definition const & other) const { - if (get_name() == other.get_name()) - return lt(other); - else - return get_name() < other.get_name(); -} -format macro_definition::pp(formatter const &, options const &) const { return format(get_name()); } -void macro_definition::display(std::ostream & out) const { out << get_name(); } -bool macro_definition::is_atomic_pp(bool, bool) const { return true; } -unsigned macro_definition::hash() const { return get_name().hash(); } +bool macro_definition_cell::lt(macro_definition_cell const &) const { return false; } +bool macro_definition_cell::operator==(macro_definition_cell const & other) const { return typeid(*this) == typeid(other); } +unsigned macro_definition_cell::trust_level() const { return 0; } -typedef std::unordered_map macro_readers; +format macro_definition_cell::pp(formatter const &, options const &) const { return format(get_name()); } +void macro_definition_cell::display(std::ostream & out) const { out << get_name(); } +bool macro_definition_cell::is_atomic_pp(bool, bool) const { return true; } +unsigned macro_definition_cell::hash() const { return get_name().hash(); } + +typedef std::unordered_map macro_readers; static std::unique_ptr g_macro_readers; macro_readers & get_macro_readers() { if (!g_macro_readers) @@ -202,7 +198,7 @@ macro_readers & get_macro_readers() { return *(g_macro_readers.get()); } -void macro_definition::register_deserializer(std::string const & k, macro_definition::reader rd) { +void macro_definition_cell::register_deserializer(std::string const & k, macro_definition_cell::reader rd) { macro_readers & readers = get_macro_readers(); lean_assert(readers.find(k) == readers.end()); readers[k] = rd; @@ -215,6 +211,19 @@ static expr read_macro_definition(deserializer & d, unsigned num, expr const * a return it->second(d, num, args); } +macro_definition::macro_definition(macro_definition_cell * ptr):m_ptr(ptr) { lean_assert(m_ptr); m_ptr->inc_ref(); } +macro_definition::macro_definition(macro_definition const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } +macro_definition::macro_definition(macro_definition && s):m_ptr(s.m_ptr) { s.m_ptr = nullptr; } +macro_definition::~macro_definition() { if (m_ptr) m_ptr->dec_ref(); } +macro_definition & macro_definition::operator=(macro_definition const & s) { LEAN_COPY_REF(s); } +macro_definition & macro_definition::operator=(macro_definition && s) { LEAN_MOVE_REF(s); } +bool macro_definition::operator<(macro_definition const & other) const { + if (get_name() == other.get_name()) + return m_ptr->lt(*other.m_ptr); + else + return get_name() < other.get_name(); +} + static unsigned max_depth(unsigned num, expr const * args) { unsigned r = 0; for (unsigned i = 0; i < num; i++) { @@ -235,9 +244,9 @@ static unsigned get_free_var_range(unsigned num, expr const * args) { return r; } -expr_macro::expr_macro(macro_definition * m, unsigned num, expr const * args): +expr_macro::expr_macro(macro_definition const & m, unsigned num, expr const * args): expr_composite(expr_kind::Macro, - lean::hash(num, [&](unsigned i) { return args[i].hash(); }, m->hash()), + lean::hash(num, [&](unsigned i) { return args[i].hash(); }, m.hash()), std::any_of(args, args+num, [](expr const & e) { return e.has_metavar(); }), std::any_of(args, args+num, [](expr const & e) { return e.has_local(); }), std::any_of(args, args+num, [](expr const & e) { return e.has_param_univ(); }), @@ -245,13 +254,11 @@ expr_macro::expr_macro(macro_definition * m, unsigned num, expr const * args): get_free_var_range(num, args)), m_definition(m), m_num_args(num) { - m_definition->inc_ref(); m_args = new expr[num]; for (unsigned i = 0; i < m_num_args; i++) m_args[i] = args[i]; } void expr_macro::dealloc(buffer & todelete) { - m_definition->dec_ref(); for (unsigned i = 0; i < m_num_args; i++) dec_ref(m_args[i], todelete); delete(this); } diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 3fed9dcf5..90c235957 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -121,7 +121,7 @@ public: friend expr mk_proj(bool fst, expr const & p); friend expr mk_binder(expr_kind k, name const & n, expr const & t, expr const & e, expr_binder_info const & i); friend expr mk_let(name const & n, expr const & t, expr const & v, expr const & e); - friend expr mk_macro(macro_definition * m, unsigned num, expr const * args); + friend expr mk_macro(macro_definition const & m, unsigned num, expr const * args); friend bool is_eqp(expr const & a, expr const & b) { return a.m_ptr == b.m_ptr; } // Overloaded operator() can be used to create applications @@ -262,29 +262,26 @@ public: class formatter; -/** \brief Base class for macro definition attachments */ -class macro_definition { +/** \brief Abstract class for macro_definitions */ +class macro_definition_cell { +protected: void dealloc() { delete this; } MK_LEAN_RC(); -protected: /** - \brief Auxiliary method used for implementing a total order on macro - attachments. It is invoked by operator<, and it is only invoked when - get_name() == other.get_name() + \brief Auxiliary method used for implementing a total order on macro + attachments. It is invoked by operator<, and it is only invoked when + get_name() == other.get_name() */ - virtual bool lt(macro_definition const &) const { return false; } + virtual bool lt(macro_definition_cell const &) const; public: - macro_definition():m_rc(0) {} - virtual ~macro_definition() {} + macro_definition_cell():m_rc(0) {} + virtual ~macro_definition_cell() {} virtual name get_name() const = 0; virtual expr get_type(unsigned num, expr const * args, expr const * arg_types, extension_context & ctx) const = 0; virtual optional expand1(unsigned num, expr const * args, extension_context & ctx) const = 0; virtual optional expand(unsigned num, expr const * args, extension_context & ctx) const = 0; - virtual unsigned trust_level() const { return 0; } - virtual int push_lua(lua_State * L) const; - virtual bool operator==(macro_definition const & other) const; - bool operator!=(macro_definition const & other) const { return !operator==(other); } - bool operator<(macro_definition const & other) const; + virtual unsigned trust_level() const; + virtual bool operator==(macro_definition_cell const & other) const; virtual void display(std::ostream & out) const; virtual format pp(formatter const & fmt, options const & opts) const; virtual bool is_atomic_pp(bool unicode, bool coercion) const; @@ -293,24 +290,56 @@ public: typedef std::function reader; static void register_deserializer(std::string const & k, reader rd); struct register_deserializer_fn { - register_deserializer_fn(std::string const & k, macro_definition::reader rd) { macro_definition::register_deserializer(k, rd); } + register_deserializer_fn(std::string const & k, macro_definition_cell::reader rd) { + macro_definition_cell::register_deserializer(k, rd); + } }; }; +/** \brief Smart pointer for macro definitions */ +class macro_definition { +public: + macro_definition_cell * m_ptr; +public: + explicit macro_definition(macro_definition_cell * ptr); + macro_definition(macro_definition const & s); + macro_definition(macro_definition && s); + ~macro_definition(); + + macro_definition & operator=(macro_definition const & s); + macro_definition & operator=(macro_definition && s); + + name get_name() const { return m_ptr->get_name(); } + expr get_type(unsigned num, expr const * args, expr const * arg_types, extension_context & ctx) const { + return m_ptr->get_type(num, args, arg_types, ctx); + } + optional expand1(unsigned num, expr const * args, extension_context & ctx) const { return m_ptr->expand1(num, args, ctx); } + optional expand(unsigned num, expr const * args, extension_context & ctx) const { return m_ptr->expand(num, args, ctx); } + unsigned trust_level() const { return m_ptr->trust_level(); } + bool operator==(macro_definition const & other) const { return m_ptr->operator==(*other.m_ptr); } + bool operator!=(macro_definition const & other) const { return !operator==(other); } + bool operator<(macro_definition const & other) const; + void display(std::ostream & out) const { return m_ptr->display(out); } + format pp(formatter const & fmt, options const & opts) const { return m_ptr->pp(fmt, opts); } + bool is_atomic_pp(bool unicode, bool coercion) const { return m_ptr->is_atomic_pp(unicode, coercion); } + unsigned hash() const { return m_ptr->hash(); } + void write(serializer & s) const { return m_ptr->write(s); } +}; + /** \brief Macro attachments */ class expr_macro : public expr_composite { - macro_definition * m_definition; - unsigned m_num_args; - expr * m_args; + macro_definition m_definition; + unsigned m_num_args; + expr * m_args; friend class expr_cell; friend expr copy(expr const & a); friend expr update_macro(expr const & e, unsigned num, expr const * args); void dealloc(buffer & todelete); public: - expr_macro(macro_definition * v, unsigned num, expr const * args); + expr_macro(macro_definition const & v, unsigned num, expr const * args); ~expr_macro(); - macro_definition const & get_def() const { return *m_definition; } + macro_definition const & get_def() const { return m_definition; } expr const * get_args() const { return m_args; } expr const & get_arg(unsigned idx) const { lean_assert(idx < m_num_args); return m_args[idx]; } unsigned get_num_args() const { return m_num_args; } @@ -357,7 +386,7 @@ inline expr Var(unsigned idx) { return mk_var(idx); } inline expr mk_constant(name const & n, levels const & ls) { return expr(new expr_const(n, ls)); } inline expr mk_constant(name const & n) { return mk_constant(n, levels()); } inline expr Const(name const & n) { return mk_constant(n); } -inline expr mk_macro(macro_definition * m, unsigned num = 0, expr const * args = nullptr) { return expr(new expr_macro(m, num, args)); } +inline expr mk_macro(macro_definition const & m, unsigned num = 0, expr const * args = nullptr) { return expr(new expr_macro(m, num, args)); } inline expr mk_mlocal(bool is_meta, name const & n, expr const & t) { return expr(new expr_mlocal(is_meta, n, t)); } inline expr mk_metavar(name const & n, expr const & t) { return mk_mlocal(true, n, t); } inline expr mk_local(name const & n, expr const & t) { return mk_mlocal(false, n, t); }