diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 5d433c8dc..7c6bc3e40 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -179,21 +179,21 @@ void expr_let::dealloc(buffer & todelete) { } expr_let::~expr_let() {} -// Macro attachment -int macro::push_lua(lua_State *) const { return 0; } // NOLINT -void macro::display(std::ostream & out) const { out << get_name(); } -bool macro::operator==(macro const & other) const { return typeid(*this) == typeid(other); } -bool macro::operator<(macro const & other) const { +// Macro definition +int macro_definition::push_lua(lua_State *) const { return 0; } // NOLINT +bool macro_definition::operator==(macro_definition const & other) const { return typeid(*this) == typeid(other); } +bool macro_definition::operator<(macro_definition const & other) const { if (get_name() == other.get_name()) return lt(other); else return get_name() < other.get_name(); } -format macro::pp(formatter const &, options const &) const { return format(get_name()); } -bool macro::is_atomic_pp(bool, bool) const { return true; } -unsigned macro::hash() const { return get_name().hash(); } +format macro_definition::pp(formatter const &, options const &) const { return format(get_name()); } +void macro_definition::display(std::ostream & out) const { out << get_name(); } +bool macro_definition::is_atomic_pp(bool, bool) const { return true; } +unsigned macro_definition::hash() const { return get_name().hash(); } -typedef std::unordered_map macro_readers; +typedef std::unordered_map macro_readers; static std::unique_ptr g_macro_readers; macro_readers & get_macro_readers() { if (!g_macro_readers) @@ -201,26 +201,61 @@ macro_readers & get_macro_readers() { return *(g_macro_readers.get()); } -void macro::register_deserializer(std::string const & k, macro::reader rd) { +void macro_definition::register_deserializer(std::string const & k, macro_definition::reader rd) { macro_readers & readers = get_macro_readers(); lean_assert(readers.find(k) == readers.end()); readers[k] = rd; } -static expr read_macro(deserializer & d) { +static expr read_macro_definition(deserializer & d, unsigned num, expr const * args) { auto k = d.read_string(); macro_readers & readers = get_macro_readers(); auto it = readers.find(k); lean_assert(it != readers.end()); - return it->second(d); + return it->second(d, num, args); } -expr_macro::expr_macro(macro * m): - expr_cell(expr_kind::Macro, m->hash(), false, false, false), - m_macro(m) { - m_macro->inc_ref(); +static unsigned max_depth(unsigned num, expr const * args) { + unsigned r = 0; + for (unsigned i = 0; i < num; i++) { + unsigned d = get_depth(args[i]); + if (d > r) + r = d; + } + return r; +} + +static unsigned get_free_var_range(unsigned num, expr const * args) { + unsigned r = 0; + for (unsigned i = 0; i < num; i++) { + unsigned d = get_free_var_range(args[i]); + if (d > r) + r = d; + } + return r; +} + +expr_macro::expr_macro(macro_definition * m, unsigned num, expr const * args): + expr_composite(expr_kind::Macro, + lean::hash(num, [&](unsigned i) { return args[i].hash(); }, m->hash()), + std::any_of(args, args+num, [](expr const & e) { return e.has_metavar(); }), + std::any_of(args, args+num, [](expr const & e) { return e.has_local(); }), + std::any_of(args, args+num, [](expr const & e) { return e.has_param_univ(); }), + max_depth(num, args) + 1, + get_free_var_range(num, args)), + m_definition(m), + m_num_args(num) { + m_definition->inc_ref(); + m_args = new expr[num]; + for (unsigned i = 0; i < m_num_args; i++) + m_args[i] = args[i]; +} +void expr_macro::dealloc(buffer & todelete) { + m_definition->dec_ref(); + for (unsigned i = 0; i < m_num_args; i++) dec_ref(m_args[i], todelete); + delete(this); } expr_macro::~expr_macro() { - m_macro->dec_ref(); + delete[] m_args; } void expr_cell::dealloc() { @@ -233,7 +268,7 @@ void expr_cell::dealloc() { lean_assert(it->get_rc() == 0); switch (it->kind()) { case expr_kind::Var: delete static_cast(it); break; - case expr_kind::Macro: delete static_cast(it); break; + case expr_kind::Macro: static_cast(it)->dealloc(todo); break; case expr_kind::Meta: case expr_kind::Local: static_cast(it)->dealloc(todo); break; case expr_kind::Constant: delete static_cast(it); break; @@ -290,9 +325,9 @@ expr mk_Type() { return Type; } unsigned get_depth(expr const & e) { switch (e.kind()) { case expr_kind::Var: case expr_kind::Constant: case expr_kind::Sort: - case expr_kind::Meta: case expr_kind::Local: case expr_kind::Macro: + case expr_kind::Meta: case expr_kind::Local: return 1; - case expr_kind::Lambda: case expr_kind::Pi: + case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Macro: case expr_kind::App: case expr_kind::Let: return static_cast(e.raw())->m_depth; } @@ -303,12 +338,13 @@ unsigned get_free_var_range(expr const & e) { switch (e.kind()) { case expr_kind::Var: return var_idx(e) + 1; - case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Macro: + case expr_kind::Constant: case expr_kind::Sort: return 0; case expr_kind::Meta: case expr_kind::Local: return get_free_var_range(mlocal_type(e)); case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::App: case expr_kind::Let: + case expr_kind::Macro: return static_cast(e.raw())->m_free_var_range; } lean_unreachable(); // LCOV_EXCL_LINE @@ -370,11 +406,26 @@ expr update_constant(expr const & e, levels const & new_levels) { return e; } +expr update_macro(expr const & e, unsigned num, expr const * args) { + if (num == macro_num_args(e)) { + unsigned i = 0; + for (i = 0; i < num; i++) { + if (!is_eqp(macro_arg(e, i), args[i])) + break; + } + if (i == num) + return e; + } + return mk_macro(to_macro(e)->m_definition, num, args); +} + bool is_atomic(expr const & e) { switch (e.kind()) { case expr_kind::Constant: case expr_kind::Sort: - case expr_kind::Macro: case expr_kind::Var: + case expr_kind::Var: return true; + case expr_kind::Macro: + return to_macro(e)->get_num_args() == 0; case expr_kind::App: case expr_kind::Let: case expr_kind::Meta: case expr_kind::Local: case expr_kind::Lambda: case expr_kind::Pi: @@ -399,7 +450,7 @@ expr copy(expr const & a) { case expr_kind::Var: return mk_var(var_idx(a)); case expr_kind::Constant: return mk_constant(const_name(a), const_level_params(a)); case expr_kind::Sort: return mk_sort(sort_level(a)); - case expr_kind::Macro: return mk_macro(static_cast(a.raw())->m_macro); + case expr_kind::Macro: return mk_macro(to_macro(a)->m_definition, macro_num_args(a), macro_args(a)); case expr_kind::App: return mk_app(app_fn(a), app_arg(a)); case expr_kind::Lambda: return mk_lambda(binder_name(a), binder_domain(a), binder_body(a)); case expr_kind::Pi: return mk_pi(binder_name(a), binder_domain(a), binder_body(a)); @@ -429,7 +480,11 @@ class expr_serializer : public object_serializer args; + for (unsigned i = 0; i < n; i++) { + args.push_back(read()); + } + return read_macro_definition(d, args.size(), args.data()); + } case expr_kind::App: { expr f = read(); return mk_app(f, read()); diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 53d034d8f..c38f88ab3 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -70,7 +70,7 @@ public: bool has_param_univ() const { return m_has_param_univ; } }; -class macro; +class macro_definition; /** \brief Exprs for encoding formulas/expressions, types and proofs. @@ -119,7 +119,7 @@ public: friend expr mk_proj(bool fst, expr const & p); friend expr mk_binder(expr_kind k, name const & n, expr const & t, expr const & e); friend expr mk_let(name const & n, expr const & t, expr const & v, expr const & e); - friend expr mk_macro(macro * m); + friend expr mk_macro(macro_definition * m, unsigned num, expr const * args); friend bool is_eqp(expr const & a, expr const & b) { return a.m_ptr == b.m_ptr; } // Overloaded operator() can be used to create applications @@ -245,8 +245,8 @@ public: class formatter; -/** \brief Base class for macro attachments */ -class macro { +/** \brief Base class for macro definition attachments */ +class macro_definition { void dealloc() { delete this; } MK_LEAN_RC(); protected: @@ -255,39 +255,48 @@ protected: attachments. It is invoked by operator<, and it is only invoked when get_name() == other.get_name() */ - virtual bool lt(macro const &) const { return false; } + virtual bool lt(macro_definition const &) const { return false; } public: - macro():m_rc(0) {} - virtual ~macro() {} + macro_definition():m_rc(0) {} + virtual ~macro_definition() {} virtual name get_name() const = 0; - virtual expr get_type(unsigned num_args, expr const * args, expr const * arg_types) const = 0; - virtual optional expand1(unsigned num_args, expr const * args) const = 0; - virtual optional expand(unsigned num_args, expr const * args) const = 0; + virtual expr get_type(unsigned num, expr const * args, expr const * arg_types) const = 0; + virtual optional expand1(unsigned num, expr const * args) const = 0; + virtual optional expand(unsigned num, expr const * args) const = 0; virtual unsigned trust_level() const = 0; virtual int push_lua(lua_State * L) const; - virtual bool operator==(macro const & other) const; - bool operator<(macro const & other) const; + virtual bool operator==(macro_definition const & other) const; + bool operator!=(macro_definition const & other) const { return !operator==(other); } + bool operator<(macro_definition const & other) const; virtual void display(std::ostream & out) const; virtual format pp(formatter const & fmt, options const & opts) const; virtual bool is_atomic_pp(bool unicode, bool coercion) const; virtual unsigned hash() const; virtual void write(serializer & s) const = 0; - typedef std::function reader; + typedef std::function reader; static void register_deserializer(std::string const & k, reader rd); struct register_deserializer_fn { - register_deserializer_fn(std::string const & k, macro::reader rd) { macro::register_deserializer(k, rd); } + register_deserializer_fn(std::string const & k, macro_definition::reader rd) { macro_definition::register_deserializer(k, rd); } }; }; /** \brief Macro attachments */ -class expr_macro : public expr_cell { - macro * m_macro; +class expr_macro : public expr_composite { + macro_definition * m_definition; + unsigned m_num_args; + expr * m_args; + friend class expr_cell; friend expr copy(expr const & a); + friend expr update_macro(expr const & e, unsigned num, expr const * args); + void dealloc(buffer & todelete); public: - expr_macro(macro * v); + expr_macro(macro_definition * v, unsigned num, expr const * args); ~expr_macro(); - macro const & get_macro() const { return *m_macro; } + macro_definition const & get_def() const { return *m_definition; } + expr const * get_args() const { return m_args; } + expr const & get_arg(unsigned idx) const { lean_assert(idx < m_num_args); return m_args[idx]; } + unsigned get_num_args() const { return m_num_args; } }; // ======================================= @@ -331,7 +340,7 @@ inline expr Var(unsigned idx) { return mk_var(idx); } inline expr mk_constant(name const & n, levels const & ls) { return expr(new expr_const(n, ls)); } inline expr mk_constant(name const & n) { return mk_constant(n, levels()); } inline expr Const(name const & n) { return mk_constant(n); } -inline expr mk_macro(macro * m) { return expr(new expr_macro(m)); } +inline expr mk_macro(macro_definition * m, unsigned num = 0, expr const * args = nullptr) { return expr(new expr_macro(m, num, args)); } inline expr mk_mlocal(bool is_meta, name const & n, expr const & t) { return expr(new expr_mlocal(is_meta, n, t)); } inline expr mk_metavar(name const & n, expr const & t) { return mk_mlocal(true, n, t); } inline expr mk_local(name const & n, expr const & t) { return mk_mlocal(false, n, t); } @@ -399,6 +408,7 @@ inline expr_sort * to_sort(expr_cell * e) { lean_assert(is_sort(e)) inline expr_mlocal * to_mlocal(expr_cell * e) { lean_assert(is_mlocal(e)); return static_cast(e); } inline expr_mlocal * to_local(expr_cell * e) { lean_assert(is_local(e)); return static_cast(e); } inline expr_mlocal * to_metavar(expr_cell * e) { lean_assert(is_metavar(e)); return static_cast(e); } +inline expr_macro * to_macro(expr_cell * e) { lean_assert(is_macro(e)); return static_cast(e); } inline expr_var * to_var(expr const & e) { return to_var(e.raw()); } inline expr_const * to_constant(expr const & e) { return to_constant(e.raw()); } @@ -409,51 +419,57 @@ inline expr_sort * to_sort(expr const & e) { return to_sort(e.raw( inline expr_mlocal * to_mlocal(expr const & e) { return to_mlocal(e.raw()); } inline expr_mlocal * to_metavar(expr const & e) { return to_metavar(e.raw()); } inline expr_mlocal * to_local(expr const & e) { return to_local(e.raw()); } +inline expr_macro * to_macro(expr const & e) { return to_macro(e.raw()); } // ======================================= // ======================================= // Accessors -inline unsigned get_rc(expr_cell * e) { return e->get_rc(); } -inline bool is_shared(expr_cell * e) { return get_rc(e) > 1; } -inline unsigned var_idx(expr_cell * e) { return to_var(e)->get_vidx(); } -inline bool is_var(expr_cell * e, unsigned i) { return is_var(e) && var_idx(e) == i; } -inline name const & const_name(expr_cell * e) { return to_constant(e)->get_name(); } -inline levels const & const_level_params(expr_cell * e) { return to_constant(e)->get_level_params(); } -inline macro const & to_macro(expr_cell * e) { - lean_assert(is_macro(e)); return static_cast(e)->get_macro(); } -inline expr const & app_fn(expr_cell * e) { return to_app(e)->get_fn(); } -inline expr const & app_arg(expr_cell * e) { return to_app(e)->get_arg(); } -inline name const & binder_name(expr_cell * e) { return to_binder(e)->get_name(); } -inline expr const & binder_domain(expr_cell * e) { return to_binder(e)->get_domain(); } -inline expr const & binder_body(expr_cell * e) { return to_binder(e)->get_body(); } -inline level const & sort_level(expr_cell * e) { return to_sort(e)->get_level(); } -inline name const & let_name(expr_cell * e) { return to_let(e)->get_name(); } -inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); } -inline expr const & let_type(expr_cell * e) { return to_let(e)->get_type(); } -inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); } -inline name const & mlocal_name(expr_cell * e) { return to_mlocal(e)->get_name(); } -inline expr const & mlocal_type(expr_cell * e) { return to_mlocal(e)->get_type(); } +inline unsigned get_rc(expr_cell * e) { return e->get_rc(); } +inline bool is_shared(expr_cell * e) { return get_rc(e) > 1; } +inline unsigned var_idx(expr_cell * e) { return to_var(e)->get_vidx(); } +inline bool is_var(expr_cell * e, unsigned i) { return is_var(e) && var_idx(e) == i; } +inline name const & const_name(expr_cell * e) { return to_constant(e)->get_name(); } +inline levels const & const_level_params(expr_cell * e) { return to_constant(e)->get_level_params(); } +inline macro_definition const & macro_def(expr_cell * e) { return to_macro(e)->get_def(); } +inline expr const * macro_args(expr_cell * e) { return to_macro(e)->get_args(); } +inline expr const & macro_arg(expr_cell * e, unsigned i) { return to_macro(e)->get_arg(i); } +inline unsigned macro_num_args(expr_cell * e) { return to_macro(e)->get_num_args(); } +inline expr const & app_fn(expr_cell * e) { return to_app(e)->get_fn(); } +inline expr const & app_arg(expr_cell * e) { return to_app(e)->get_arg(); } +inline name const & binder_name(expr_cell * e) { return to_binder(e)->get_name(); } +inline expr const & binder_domain(expr_cell * e) { return to_binder(e)->get_domain(); } +inline expr const & binder_body(expr_cell * e) { return to_binder(e)->get_body(); } +inline level const & sort_level(expr_cell * e) { return to_sort(e)->get_level(); } +inline name const & let_name(expr_cell * e) { return to_let(e)->get_name(); } +inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); } +inline expr const & let_type(expr_cell * e) { return to_let(e)->get_type(); } +inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); } +inline name const & mlocal_name(expr_cell * e) { return to_mlocal(e)->get_name(); } +inline expr const & mlocal_type(expr_cell * e) { return to_mlocal(e)->get_type(); } -inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); } -inline bool is_shared(expr const & e) { return get_rc(e) > 1; } -inline unsigned var_idx(expr const & e) { return to_var(e)->get_vidx(); } -inline bool is_var(expr const & e, unsigned i) { return is_var(e) && var_idx(e) == i; } -inline name const & const_name(expr const & e) { return to_constant(e)->get_name(); } -inline levels const & const_level_params(expr const & e) { return to_constant(e)->get_level_params(); } -inline macro const & to_macro(expr const & e) { return to_macro(e.raw()); } -inline expr const & app_fn(expr const & e) { return to_app(e)->get_fn(); } -inline expr const & app_arg(expr const & e) { return to_app(e)->get_arg(); } -inline name const & binder_name(expr const & e) { return to_binder(e)->get_name(); } -inline expr const & binder_domain(expr const & e) { return to_binder(e)->get_domain(); } -inline expr const & binder_body(expr const & e) { return to_binder(e)->get_body(); } -inline level const & sort_level(expr const & e) { return to_sort(e)->get_level(); } -inline name const & let_name(expr const & e) { return to_let(e)->get_name(); } -inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); } -inline expr const & let_type(expr const & e) { return to_let(e)->get_type(); } -inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } -inline name const & mlocal_name(expr const & e) { return to_mlocal(e)->get_name(); } -inline expr const & mlocal_type(expr const & e) { return to_mlocal(e)->get_type(); } +inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); } +inline bool is_shared(expr const & e) { return get_rc(e) > 1; } +inline unsigned var_idx(expr const & e) { return to_var(e)->get_vidx(); } +inline bool is_var(expr const & e, unsigned i) { return is_var(e) && var_idx(e) == i; } +inline name const & const_name(expr const & e) { return to_constant(e)->get_name(); } +inline levels const & const_level_params(expr const & e) { return to_constant(e)->get_level_params(); } +inline macro_definition const & macro_def(expr const & e) { return to_macro(e)->get_def(); } +inline expr const * macro_args(expr const & e) { return to_macro(e)->get_args(); } +inline expr const & macro_arg(expr const & e, unsigned i) { return to_macro(e)->get_arg(i); } +inline unsigned macro_num_args(expr const & e) { return to_macro(e)->get_num_args(); } +inline expr const & app_fn(expr const & e) { return to_app(e)->get_fn(); } +inline expr const & app_arg(expr const & e) { return to_app(e)->get_arg(); } +inline name const & binder_name(expr const & e) { return to_binder(e)->get_name(); } +inline expr const & binder_domain(expr const & e) { return to_binder(e)->get_domain(); } +inline expr const & binder_body(expr const & e) { return to_binder(e)->get_body(); } +inline level const & sort_level(expr const & e) { return to_sort(e)->get_level(); } +inline name const & let_name(expr const & e) { return to_let(e)->get_name(); } +inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); } +inline expr const & let_type(expr const & e) { return to_let(e)->get_type(); } +inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } +inline name const & mlocal_name(expr const & e) { return to_mlocal(e)->get_name(); } +inline expr const & mlocal_type(expr const & e) { return to_mlocal(e)->get_type(); } inline bool is_constant(expr const & e, name const & n) { return is_constant(e) && const_name(e) == n; } inline bool has_metavar(expr const & e) { return e.has_metavar(); } @@ -522,6 +538,7 @@ expr update_let(expr const & e, expr const & new_type, expr const & new_val, exp expr update_mlocal(expr const & e, expr const & new_type); expr update_sort(expr const & e, level const & new_level); expr update_constant(expr const & e, levels const & new_levels); +expr update_macro(expr const & e, unsigned num, expr const * args); // ======================================= // ======================================= diff --git a/src/kernel/expr_eq_fn.cpp b/src/kernel/expr_eq_fn.cpp index f4a7f8e48..f5a81d303 100644 --- a/src/kernel/expr_eq_fn.cpp +++ b/src/kernel/expr_eq_fn.cpp @@ -43,7 +43,13 @@ bool expr_eq_fn::apply(expr const & a, expr const & b) { case expr_kind::Sort: return sort_level(a) == sort_level(b); case expr_kind::Macro: - return to_macro(a) == to_macro(b); + if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b)) + return false; + for (unsigned i = 0; i < macro_num_args(a); i++) { + if (!apply(macro_arg(a, i), macro_arg(b, i))) + return false; + } + return true; case expr_kind::Let: return apply(let_type(a), let_type(b)) && diff --git a/src/kernel/for_each_fn.cpp b/src/kernel/for_each_fn.cpp index 3162eb874..ba690938d 100644 --- a/src/kernel/for_each_fn.cpp +++ b/src/kernel/for_each_fn.cpp @@ -22,7 +22,7 @@ void for_each_fn::apply(expr const & e, unsigned offset) { switch (e.kind()) { case expr_kind::Constant: case expr_kind::Var: - case expr_kind::Sort: case expr_kind::Macro: + case expr_kind::Sort: m_f(e, offset); goto begin_loop; default: @@ -43,11 +43,19 @@ void for_each_fn::apply(expr const & e, unsigned offset) { switch (e.kind()) { case expr_kind::Constant: case expr_kind::Var: - case expr_kind::Sort: case expr_kind::Macro: + case expr_kind::Sort: goto begin_loop; case expr_kind::Meta: case expr_kind::Local: todo.emplace_back(mlocal_type(e), offset); goto begin_loop; + case expr_kind::Macro: { + unsigned i = macro_num_args(e); + while (i > 0) { + --i; + todo.emplace_back(macro_arg(e, i), offset); + } + goto begin_loop; + } case expr_kind::App: todo.emplace_back(app_arg(e), offset); todo.emplace_back(app_fn(e), offset); diff --git a/src/kernel/formatter.cpp b/src/kernel/formatter.cpp index b86cf0664..83ef6c5ea 100644 --- a/src/kernel/formatter.cpp +++ b/src/kernel/formatter.cpp @@ -29,8 +29,13 @@ struct print_expr_fn { } } - void print_macro(expr const & a) { - to_macro(a).display(out()); + void print_macro(expr const & a, context const & c) { + if (macro_num_args(a) > 0) out() << "("; + macro_def(a).display(out()); + for (unsigned i = 0; i < macro_num_args(a); i++) { + out() << " "; print(macro_arg(a, i), c); + } + if (macro_num_args(a) > 0) out() << ")"; } void print_sort(expr const & a) { @@ -109,7 +114,7 @@ struct print_expr_fn { print_sort(a); break; case expr_kind::Macro: - print_macro(a); + print_macro(a, c); break; } } diff --git a/src/kernel/free_vars.cpp b/src/kernel/free_vars.cpp index 2a14b35ac..f6aebd9c1 100644 --- a/src/kernel/free_vars.cpp +++ b/src/kernel/free_vars.cpp @@ -57,7 +57,7 @@ protected: bool result = false; switch (e.kind()) { - case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Macro: + case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Var: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Meta: case expr_kind::Local: @@ -72,6 +72,14 @@ protected: case expr_kind::Let: result = apply(let_type(e), offset) || apply(let_value(e), offset) || apply(let_body(e), offset + 1); break; + case expr_kind::Macro: + for (unsigned i = 0; i < macro_num_args(e); i++) { + if (apply(macro_arg(e, i), offset)) { + result = true; + break; + } + } + break; } if (!result && shared) { diff --git a/src/kernel/max_sharing.cpp b/src/kernel/max_sharing.cpp index 7d8b91b00..845c1645b 100644 --- a/src/kernel/max_sharing.cpp +++ b/src/kernel/max_sharing.cpp @@ -29,7 +29,7 @@ struct max_sharing_fn::imp { expr res; switch (a.kind()) { case expr_kind::Constant: case expr_kind::Var: - case expr_kind::Sort: case expr_kind::Macro: + case expr_kind::Sort: res = a; break; case expr_kind::App: @@ -44,7 +44,13 @@ struct max_sharing_fn::imp { case expr_kind::Meta: case expr_kind::Local: res = update_mlocal(a, apply(mlocal_type(a))); break; - } + case expr_kind::Macro: { + buffer new_args; + for (unsigned i = 0; i < macro_num_args(a); i++) + new_args.push_back(macro_arg(a, i)); + res = update_macro(a, new_args.size(), new_args.data()); + break; + }} m_cache.insert(res); return res; } diff --git a/src/kernel/replace_fn.cpp b/src/kernel/replace_fn.cpp index dbfa35f98..2dc117648 100644 --- a/src/kernel/replace_fn.cpp +++ b/src/kernel/replace_fn.cpp @@ -77,7 +77,7 @@ expr replace_fn::operator()(expr const & e) { unsigned offset = f.m_offset; switch (e.kind()) { case expr_kind::Constant: case expr_kind::Sort: - case expr_kind::Macro: case expr_kind::Var: + case expr_kind::Var: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Meta: case expr_kind::Local: if (check_index(f, 0) && !visit(mlocal_type(e), offset)) @@ -111,6 +111,14 @@ expr replace_fn::operator()(expr const & e) { r = update_let(e, rs(-3), rs(-2), rs(-1)); pop_rs(3); break; + case expr_kind::Macro: + while (f.m_index < macro_num_args(e)) { + if (!visit(macro_arg(e, f.m_index), offset)) + goto begin_loop; + } + r = update_macro(e, macro_num_args(e), &rs(-macro_num_args(e))); + pop_rs(macro_num_args(e)); + break; } save_result(e, r, offset, f.m_shared); m_fs.pop_back(); diff --git a/src/kernel/replace_visitor.cpp b/src/kernel/replace_visitor.cpp index 8f97f24ea..999ff6dcd 100644 --- a/src/kernel/replace_visitor.cpp +++ b/src/kernel/replace_visitor.cpp @@ -11,7 +11,6 @@ Author: Leonardo de Moura namespace lean { expr replace_visitor::visit_sort(expr const & e, context const &) { lean_assert(is_sort(e)); return e; } -expr replace_visitor::visit_macro(expr const & e, context const &) { lean_assert(is_macro(e)); return e; } expr replace_visitor::visit_var(expr const & e, context const &) { lean_assert(is_var(e)); return e; } expr replace_visitor::visit_constant(expr const & e, context const &) { lean_assert(is_constant(e)); return e; } expr replace_visitor::visit_mlocal(expr const & e, context const & ctx) { @@ -41,6 +40,13 @@ expr replace_visitor::visit_let(expr const & e, context const & ctx) { expr new_b = visit(let_body(e), extend(ctx, let_name(e), new_t)); return update_let(e, new_t, new_v, new_b); } +expr replace_visitor::visit_macro(expr const & e, context const & ctx) { + lean_assert(is_macro(e)); + buffer new_args; + for (unsigned i = 0; i < macro_num_args(e); i++) + new_args.push_back(visit(macro_arg(e, i), ctx)); + return update_macro(e, new_args.size(), new_args.data()); +} expr replace_visitor::save_result(expr const & e, expr && r, bool shared) { if (shared) m_cache.insert(std::make_pair(e, r)); diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index b3113bedc..d452aab35 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -56,9 +56,9 @@ struct type_checker::imp { expr instantiate(expr const & e, unsigned n, expr const * s) { return max_sharing(lean::instantiate(e, n, s)); } expr instantiate(expr const & e, expr const & s) { return max_sharing(lean::instantiate(e, s)); } expr mk_rev_app(expr const & f, unsigned num, expr const * args) { return max_sharing(lean::mk_rev_app(f, num, args)); } - optional expand_macro(expr const & m, unsigned num, expr const * args) { + optional expand_macro(expr const & m) { lean_assert(is_macro(m)); - if (auto new_m = to_macro(m).expand(num, args)) + if (auto new_m = macro_def(m).expand(macro_num_args(m), macro_args(m))) return some_expr(max_sharing(*new_m)); else return none_expr(); @@ -134,7 +134,7 @@ struct type_checker::imp { case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Constant: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Macro: - if (auto m = expand_macro(e, 0, 0)) + if (auto m = expand_macro(e)) r = whnf_core(*m); else r = e; @@ -160,12 +160,6 @@ struct type_checker::imp { lean_assert(m <= num_args); r = whnf_core(mk_rev_app(instantiate(binder_body(f), m, args.data() + (num_args - m)), num_args - m, args.data())); break; - } else if (is_macro(f)) { - auto m = expand_macro(f, args.size(), args.data()); - if (m) { - r = whnf_core(*m); - break; - } } r = is_eqp(f, *it) ? e : mk_rev_app(f, args.size(), args.data()); break; @@ -664,10 +658,13 @@ struct type_checker::imp { r = instantiate_params(d.get_type(), ps, ls); break; } - case expr_kind::Macro: - r = to_macro(e).get_type(0, 0, 0); - if (!infer_only && to_macro(e).trust_level() <= m_env.trust_lvl()) { - optional m = to_macro(e).expand(0, 0); + case expr_kind::Macro: { + buffer arg_types; + for (unsigned i = 0; i < macro_num_args(e); i++) + arg_types.push_back(infer_type_core(macro_arg(e, i), infer_only)); + r = macro_def(e).get_type(macro_num_args(e), macro_args(e), arg_types.data()); + if (!infer_only && macro_def(e).trust_level() <= m_env.trust_lvl()) { + optional m = expand_macro(e); if (!m) throw_kernel_exception(m_env, "failed to expand macro", some_expr(e)); expr t = infer_type_core(*m, infer_only); @@ -676,6 +673,7 @@ struct type_checker::imp { throw_kernel_exception(m_env, g_macro_error_msg, some_expr(e)); } break; + } case expr_kind::Lambda: { if (!infer_only) { expr t = infer_type_core(binder_domain(e), infer_only); diff --git a/src/library/expr_lt.cpp b/src/library/expr_lt.cpp index cabc8990b..53327ebf9 100644 --- a/src/library/expr_lt.cpp +++ b/src/library/expr_lt.cpp @@ -92,7 +92,15 @@ bool is_lt(expr const & a, expr const & b, bool use_hash) { else return is_lt(mlocal_type(a), mlocal_type(b), use_hash); case expr_kind::Macro: - return to_macro(a) < to_macro(b); + if (macro_def(a) != macro_def(b)) + return macro_def(a) < macro_def(b); + if (macro_num_args(a) != macro_num_args(b)) + return macro_num_args(a) < macro_num_args(b); + for (unsigned i = 0; i < macro_num_args(a); i++) { + if (macro_arg(a, i) != macro_arg(b, i)) + return is_lt(macro_arg(a, i), macro_arg(b, i), use_hash); + } + return false; } lean_unreachable(); // LCOV_EXCL_LINE }