refactor(kernel): store macro arguments in the macro_expr

Before this commit, we "stored" macro arguments using applications.
This representation had some issues. Suppose we use [m a] to denote a macro
application. In the old representation, ([m a] b) and [m a b] would have
the same representation. Another problem is that some procedures (e.g., type inference)
would not have a clean implementation.

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-04-25 15:02:52 -07:00
parent 9d5ed2bae1
commit 4842ae4fc7
11 changed files with 240 additions and 110 deletions

View file

@ -179,21 +179,21 @@ void expr_let::dealloc(buffer<expr_cell*> & todelete) {
} }
expr_let::~expr_let() {} expr_let::~expr_let() {}
// Macro attachment // Macro definition
int macro::push_lua(lua_State *) const { return 0; } // NOLINT int macro_definition::push_lua(lua_State *) const { return 0; } // NOLINT
void macro::display(std::ostream & out) const { out << get_name(); } bool macro_definition::operator==(macro_definition const & other) const { return typeid(*this) == typeid(other); }
bool macro::operator==(macro const & other) const { return typeid(*this) == typeid(other); } bool macro_definition::operator<(macro_definition const & other) const {
bool macro::operator<(macro const & other) const {
if (get_name() == other.get_name()) if (get_name() == other.get_name())
return lt(other); return lt(other);
else else
return get_name() < other.get_name(); return get_name() < other.get_name();
} }
format macro::pp(formatter const &, options const &) const { return format(get_name()); } format macro_definition::pp(formatter const &, options const &) const { return format(get_name()); }
bool macro::is_atomic_pp(bool, bool) const { return true; } void macro_definition::display(std::ostream & out) const { out << get_name(); }
unsigned macro::hash() const { return get_name().hash(); } bool macro_definition::is_atomic_pp(bool, bool) const { return true; }
unsigned macro_definition::hash() const { return get_name().hash(); }
typedef std::unordered_map<std::string, macro::reader> macro_readers; typedef std::unordered_map<std::string, macro_definition::reader> macro_readers;
static std::unique_ptr<macro_readers> g_macro_readers; static std::unique_ptr<macro_readers> g_macro_readers;
macro_readers & get_macro_readers() { macro_readers & get_macro_readers() {
if (!g_macro_readers) if (!g_macro_readers)
@ -201,26 +201,61 @@ macro_readers & get_macro_readers() {
return *(g_macro_readers.get()); return *(g_macro_readers.get());
} }
void macro::register_deserializer(std::string const & k, macro::reader rd) { void macro_definition::register_deserializer(std::string const & k, macro_definition::reader rd) {
macro_readers & readers = get_macro_readers(); macro_readers & readers = get_macro_readers();
lean_assert(readers.find(k) == readers.end()); lean_assert(readers.find(k) == readers.end());
readers[k] = rd; readers[k] = rd;
} }
static expr read_macro(deserializer & d) { static expr read_macro_definition(deserializer & d, unsigned num, expr const * args) {
auto k = d.read_string(); auto k = d.read_string();
macro_readers & readers = get_macro_readers(); macro_readers & readers = get_macro_readers();
auto it = readers.find(k); auto it = readers.find(k);
lean_assert(it != readers.end()); lean_assert(it != readers.end());
return it->second(d); return it->second(d, num, args);
} }
expr_macro::expr_macro(macro * m): static unsigned max_depth(unsigned num, expr const * args) {
expr_cell(expr_kind::Macro, m->hash(), false, false, false), unsigned r = 0;
m_macro(m) { for (unsigned i = 0; i < num; i++) {
m_macro->inc_ref(); unsigned d = get_depth(args[i]);
if (d > r)
r = d;
}
return r;
}
static unsigned get_free_var_range(unsigned num, expr const * args) {
unsigned r = 0;
for (unsigned i = 0; i < num; i++) {
unsigned d = get_free_var_range(args[i]);
if (d > r)
r = d;
}
return r;
}
expr_macro::expr_macro(macro_definition * m, unsigned num, expr const * args):
expr_composite(expr_kind::Macro,
lean::hash(num, [&](unsigned i) { return args[i].hash(); }, m->hash()),
std::any_of(args, args+num, [](expr const & e) { return e.has_metavar(); }),
std::any_of(args, args+num, [](expr const & e) { return e.has_local(); }),
std::any_of(args, args+num, [](expr const & e) { return e.has_param_univ(); }),
max_depth(num, args) + 1,
get_free_var_range(num, args)),
m_definition(m),
m_num_args(num) {
m_definition->inc_ref();
m_args = new expr[num];
for (unsigned i = 0; i < m_num_args; i++)
m_args[i] = args[i];
}
void expr_macro::dealloc(buffer<expr_cell*> & todelete) {
m_definition->dec_ref();
for (unsigned i = 0; i < m_num_args; i++) dec_ref(m_args[i], todelete);
delete(this);
} }
expr_macro::~expr_macro() { expr_macro::~expr_macro() {
m_macro->dec_ref(); delete[] m_args;
} }
void expr_cell::dealloc() { void expr_cell::dealloc() {
@ -233,7 +268,7 @@ void expr_cell::dealloc() {
lean_assert(it->get_rc() == 0); lean_assert(it->get_rc() == 0);
switch (it->kind()) { switch (it->kind()) {
case expr_kind::Var: delete static_cast<expr_var*>(it); break; case expr_kind::Var: delete static_cast<expr_var*>(it); break;
case expr_kind::Macro: delete static_cast<expr_macro*>(it); break; case expr_kind::Macro: static_cast<expr_macro*>(it)->dealloc(todo); break;
case expr_kind::Meta: case expr_kind::Meta:
case expr_kind::Local: static_cast<expr_mlocal*>(it)->dealloc(todo); break; case expr_kind::Local: static_cast<expr_mlocal*>(it)->dealloc(todo); break;
case expr_kind::Constant: delete static_cast<expr_const*>(it); break; case expr_kind::Constant: delete static_cast<expr_const*>(it); break;
@ -290,9 +325,9 @@ expr mk_Type() { return Type; }
unsigned get_depth(expr const & e) { unsigned get_depth(expr const & e) {
switch (e.kind()) { switch (e.kind()) {
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Var: case expr_kind::Constant: case expr_kind::Sort:
case expr_kind::Meta: case expr_kind::Local: case expr_kind::Macro: case expr_kind::Meta: case expr_kind::Local:
return 1; return 1;
case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Macro:
case expr_kind::App: case expr_kind::Let: case expr_kind::App: case expr_kind::Let:
return static_cast<expr_composite*>(e.raw())->m_depth; return static_cast<expr_composite*>(e.raw())->m_depth;
} }
@ -303,12 +338,13 @@ unsigned get_free_var_range(expr const & e) {
switch (e.kind()) { switch (e.kind()) {
case expr_kind::Var: case expr_kind::Var:
return var_idx(e) + 1; return var_idx(e) + 1;
case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Macro: case expr_kind::Constant: case expr_kind::Sort:
return 0; return 0;
case expr_kind::Meta: case expr_kind::Local: case expr_kind::Meta: case expr_kind::Local:
return get_free_var_range(mlocal_type(e)); return get_free_var_range(mlocal_type(e));
case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Lambda: case expr_kind::Pi:
case expr_kind::App: case expr_kind::Let: case expr_kind::App: case expr_kind::Let:
case expr_kind::Macro:
return static_cast<expr_composite*>(e.raw())->m_free_var_range; return static_cast<expr_composite*>(e.raw())->m_free_var_range;
} }
lean_unreachable(); // LCOV_EXCL_LINE lean_unreachable(); // LCOV_EXCL_LINE
@ -370,11 +406,26 @@ expr update_constant(expr const & e, levels const & new_levels) {
return e; return e;
} }
expr update_macro(expr const & e, unsigned num, expr const * args) {
if (num == macro_num_args(e)) {
unsigned i = 0;
for (i = 0; i < num; i++) {
if (!is_eqp(macro_arg(e, i), args[i]))
break;
}
if (i == num)
return e;
}
return mk_macro(to_macro(e)->m_definition, num, args);
}
bool is_atomic(expr const & e) { bool is_atomic(expr const & e) {
switch (e.kind()) { switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Constant: case expr_kind::Sort:
case expr_kind::Macro: case expr_kind::Var: case expr_kind::Var:
return true; return true;
case expr_kind::Macro:
return to_macro(e)->get_num_args() == 0;
case expr_kind::App: case expr_kind::Let: case expr_kind::App: case expr_kind::Let:
case expr_kind::Meta: case expr_kind::Local: case expr_kind::Meta: case expr_kind::Local:
case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Lambda: case expr_kind::Pi:
@ -399,7 +450,7 @@ expr copy(expr const & a) {
case expr_kind::Var: return mk_var(var_idx(a)); case expr_kind::Var: return mk_var(var_idx(a));
case expr_kind::Constant: return mk_constant(const_name(a), const_level_params(a)); case expr_kind::Constant: return mk_constant(const_name(a), const_level_params(a));
case expr_kind::Sort: return mk_sort(sort_level(a)); case expr_kind::Sort: return mk_sort(sort_level(a));
case expr_kind::Macro: return mk_macro(static_cast<expr_macro*>(a.raw())->m_macro); case expr_kind::Macro: return mk_macro(to_macro(a)->m_definition, macro_num_args(a), macro_args(a));
case expr_kind::App: return mk_app(app_fn(a), app_arg(a)); case expr_kind::App: return mk_app(app_fn(a), app_arg(a));
case expr_kind::Lambda: return mk_lambda(binder_name(a), binder_domain(a), binder_body(a)); case expr_kind::Lambda: return mk_lambda(binder_name(a), binder_domain(a), binder_body(a));
case expr_kind::Pi: return mk_pi(binder_name(a), binder_domain(a), binder_body(a)); case expr_kind::Pi: return mk_pi(binder_name(a), binder_domain(a), binder_body(a));
@ -429,7 +480,11 @@ class expr_serializer : public object_serializer<expr, expr_hash_alloc, expr_eqp
s << sort_level(a); s << sort_level(a);
break; break;
case expr_kind::Macro: case expr_kind::Macro:
to_macro(a).write(s); s << macro_num_args(a);
for (unsigned i = 0; i < macro_num_args(a); i++) {
write_core(macro_arg(a, i));
}
macro_def(a).write(s);
break; break;
case expr_kind::App: case expr_kind::App:
write_core(app_fn(a)); write_core(app_arg(a)); write_core(app_fn(a)); write_core(app_arg(a));
@ -475,9 +530,14 @@ public:
} }
case expr_kind::Sort: case expr_kind::Sort:
return mk_sort(read_level(d)); return mk_sort(read_level(d));
break; case expr_kind::Macro: {
case expr_kind::Macro: unsigned n = d.read_unsigned();
return read_macro(d); buffer<expr> args;
for (unsigned i = 0; i < n; i++) {
args.push_back(read());
}
return read_macro_definition(d, args.size(), args.data());
}
case expr_kind::App: { case expr_kind::App: {
expr f = read(); expr f = read();
return mk_app(f, read()); return mk_app(f, read());

View file

@ -70,7 +70,7 @@ public:
bool has_param_univ() const { return m_has_param_univ; } bool has_param_univ() const { return m_has_param_univ; }
}; };
class macro; class macro_definition;
/** /**
\brief Exprs for encoding formulas/expressions, types and proofs. \brief Exprs for encoding formulas/expressions, types and proofs.
@ -119,7 +119,7 @@ public:
friend expr mk_proj(bool fst, expr const & p); friend expr mk_proj(bool fst, expr const & p);
friend expr mk_binder(expr_kind k, name const & n, expr const & t, expr const & e); friend expr mk_binder(expr_kind k, name const & n, expr const & t, expr const & e);
friend expr mk_let(name const & n, expr const & t, expr const & v, expr const & e); friend expr mk_let(name const & n, expr const & t, expr const & v, expr const & e);
friend expr mk_macro(macro * m); friend expr mk_macro(macro_definition * m, unsigned num, expr const * args);
friend bool is_eqp(expr const & a, expr const & b) { return a.m_ptr == b.m_ptr; } friend bool is_eqp(expr const & a, expr const & b) { return a.m_ptr == b.m_ptr; }
// Overloaded operator() can be used to create applications // Overloaded operator() can be used to create applications
@ -245,8 +245,8 @@ public:
class formatter; class formatter;
/** \brief Base class for macro attachments */ /** \brief Base class for macro definition attachments */
class macro { class macro_definition {
void dealloc() { delete this; } void dealloc() { delete this; }
MK_LEAN_RC(); MK_LEAN_RC();
protected: protected:
@ -255,39 +255,48 @@ protected:
attachments. It is invoked by operator<, and it is only invoked when attachments. It is invoked by operator<, and it is only invoked when
<tt>get_name() == other.get_name()</tt> <tt>get_name() == other.get_name()</tt>
*/ */
virtual bool lt(macro const &) const { return false; } virtual bool lt(macro_definition const &) const { return false; }
public: public:
macro():m_rc(0) {} macro_definition():m_rc(0) {}
virtual ~macro() {} virtual ~macro_definition() {}
virtual name get_name() const = 0; virtual name get_name() const = 0;
virtual expr get_type(unsigned num_args, expr const * args, expr const * arg_types) const = 0; virtual expr get_type(unsigned num, expr const * args, expr const * arg_types) const = 0;
virtual optional<expr> expand1(unsigned num_args, expr const * args) const = 0; virtual optional<expr> expand1(unsigned num, expr const * args) const = 0;
virtual optional<expr> expand(unsigned num_args, expr const * args) const = 0; virtual optional<expr> expand(unsigned num, expr const * args) const = 0;
virtual unsigned trust_level() const = 0; virtual unsigned trust_level() const = 0;
virtual int push_lua(lua_State * L) const; virtual int push_lua(lua_State * L) const;
virtual bool operator==(macro const & other) const; virtual bool operator==(macro_definition const & other) const;
bool operator<(macro const & other) const; bool operator!=(macro_definition const & other) const { return !operator==(other); }
bool operator<(macro_definition const & other) const;
virtual void display(std::ostream & out) const; virtual void display(std::ostream & out) const;
virtual format pp(formatter const & fmt, options const & opts) const; virtual format pp(formatter const & fmt, options const & opts) const;
virtual bool is_atomic_pp(bool unicode, bool coercion) const; virtual bool is_atomic_pp(bool unicode, bool coercion) const;
virtual unsigned hash() const; virtual unsigned hash() const;
virtual void write(serializer & s) const = 0; virtual void write(serializer & s) const = 0;
typedef std::function<expr(deserializer&)> reader; typedef std::function<expr(deserializer&, unsigned, expr const *)> reader;
static void register_deserializer(std::string const & k, reader rd); static void register_deserializer(std::string const & k, reader rd);
struct register_deserializer_fn { struct register_deserializer_fn {
register_deserializer_fn(std::string const & k, macro::reader rd) { macro::register_deserializer(k, rd); } register_deserializer_fn(std::string const & k, macro_definition::reader rd) { macro_definition::register_deserializer(k, rd); }
}; };
}; };
/** \brief Macro attachments */ /** \brief Macro attachments */
class expr_macro : public expr_cell { class expr_macro : public expr_composite {
macro * m_macro; macro_definition * m_definition;
unsigned m_num_args;
expr * m_args;
friend class expr_cell;
friend expr copy(expr const & a); friend expr copy(expr const & a);
friend expr update_macro(expr const & e, unsigned num, expr const * args);
void dealloc(buffer<expr_cell*> & todelete);
public: public:
expr_macro(macro * v); expr_macro(macro_definition * v, unsigned num, expr const * args);
~expr_macro(); ~expr_macro();
macro const & get_macro() const { return *m_macro; } macro_definition const & get_def() const { return *m_definition; }
expr const * get_args() const { return m_args; }
expr const & get_arg(unsigned idx) const { lean_assert(idx < m_num_args); return m_args[idx]; }
unsigned get_num_args() const { return m_num_args; }
}; };
// ======================================= // =======================================
@ -331,7 +340,7 @@ inline expr Var(unsigned idx) { return mk_var(idx); }
inline expr mk_constant(name const & n, levels const & ls) { return expr(new expr_const(n, ls)); } inline expr mk_constant(name const & n, levels const & ls) { return expr(new expr_const(n, ls)); }
inline expr mk_constant(name const & n) { return mk_constant(n, levels()); } inline expr mk_constant(name const & n) { return mk_constant(n, levels()); }
inline expr Const(name const & n) { return mk_constant(n); } inline expr Const(name const & n) { return mk_constant(n); }
inline expr mk_macro(macro * m) { return expr(new expr_macro(m)); } inline expr mk_macro(macro_definition * m, unsigned num = 0, expr const * args = nullptr) { return expr(new expr_macro(m, num, args)); }
inline expr mk_mlocal(bool is_meta, name const & n, expr const & t) { return expr(new expr_mlocal(is_meta, n, t)); } inline expr mk_mlocal(bool is_meta, name const & n, expr const & t) { return expr(new expr_mlocal(is_meta, n, t)); }
inline expr mk_metavar(name const & n, expr const & t) { return mk_mlocal(true, n, t); } inline expr mk_metavar(name const & n, expr const & t) { return mk_mlocal(true, n, t); }
inline expr mk_local(name const & n, expr const & t) { return mk_mlocal(false, n, t); } inline expr mk_local(name const & n, expr const & t) { return mk_mlocal(false, n, t); }
@ -399,6 +408,7 @@ inline expr_sort * to_sort(expr_cell * e) { lean_assert(is_sort(e))
inline expr_mlocal * to_mlocal(expr_cell * e) { lean_assert(is_mlocal(e)); return static_cast<expr_mlocal*>(e); } inline expr_mlocal * to_mlocal(expr_cell * e) { lean_assert(is_mlocal(e)); return static_cast<expr_mlocal*>(e); }
inline expr_mlocal * to_local(expr_cell * e) { lean_assert(is_local(e)); return static_cast<expr_mlocal*>(e); } inline expr_mlocal * to_local(expr_cell * e) { lean_assert(is_local(e)); return static_cast<expr_mlocal*>(e); }
inline expr_mlocal * to_metavar(expr_cell * e) { lean_assert(is_metavar(e)); return static_cast<expr_mlocal*>(e); } inline expr_mlocal * to_metavar(expr_cell * e) { lean_assert(is_metavar(e)); return static_cast<expr_mlocal*>(e); }
inline expr_macro * to_macro(expr_cell * e) { lean_assert(is_macro(e)); return static_cast<expr_macro*>(e); }
inline expr_var * to_var(expr const & e) { return to_var(e.raw()); } inline expr_var * to_var(expr const & e) { return to_var(e.raw()); }
inline expr_const * to_constant(expr const & e) { return to_constant(e.raw()); } inline expr_const * to_constant(expr const & e) { return to_constant(e.raw()); }
@ -409,51 +419,57 @@ inline expr_sort * to_sort(expr const & e) { return to_sort(e.raw(
inline expr_mlocal * to_mlocal(expr const & e) { return to_mlocal(e.raw()); } inline expr_mlocal * to_mlocal(expr const & e) { return to_mlocal(e.raw()); }
inline expr_mlocal * to_metavar(expr const & e) { return to_metavar(e.raw()); } inline expr_mlocal * to_metavar(expr const & e) { return to_metavar(e.raw()); }
inline expr_mlocal * to_local(expr const & e) { return to_local(e.raw()); } inline expr_mlocal * to_local(expr const & e) { return to_local(e.raw()); }
inline expr_macro * to_macro(expr const & e) { return to_macro(e.raw()); }
// ======================================= // =======================================
// ======================================= // =======================================
// Accessors // Accessors
inline unsigned get_rc(expr_cell * e) { return e->get_rc(); } inline unsigned get_rc(expr_cell * e) { return e->get_rc(); }
inline bool is_shared(expr_cell * e) { return get_rc(e) > 1; } inline bool is_shared(expr_cell * e) { return get_rc(e) > 1; }
inline unsigned var_idx(expr_cell * e) { return to_var(e)->get_vidx(); } inline unsigned var_idx(expr_cell * e) { return to_var(e)->get_vidx(); }
inline bool is_var(expr_cell * e, unsigned i) { return is_var(e) && var_idx(e) == i; } inline bool is_var(expr_cell * e, unsigned i) { return is_var(e) && var_idx(e) == i; }
inline name const & const_name(expr_cell * e) { return to_constant(e)->get_name(); } inline name const & const_name(expr_cell * e) { return to_constant(e)->get_name(); }
inline levels const & const_level_params(expr_cell * e) { return to_constant(e)->get_level_params(); } inline levels const & const_level_params(expr_cell * e) { return to_constant(e)->get_level_params(); }
inline macro const & to_macro(expr_cell * e) { inline macro_definition const & macro_def(expr_cell * e) { return to_macro(e)->get_def(); }
lean_assert(is_macro(e)); return static_cast<expr_macro*>(e)->get_macro(); } inline expr const * macro_args(expr_cell * e) { return to_macro(e)->get_args(); }
inline expr const & app_fn(expr_cell * e) { return to_app(e)->get_fn(); } inline expr const & macro_arg(expr_cell * e, unsigned i) { return to_macro(e)->get_arg(i); }
inline expr const & app_arg(expr_cell * e) { return to_app(e)->get_arg(); } inline unsigned macro_num_args(expr_cell * e) { return to_macro(e)->get_num_args(); }
inline name const & binder_name(expr_cell * e) { return to_binder(e)->get_name(); } inline expr const & app_fn(expr_cell * e) { return to_app(e)->get_fn(); }
inline expr const & binder_domain(expr_cell * e) { return to_binder(e)->get_domain(); } inline expr const & app_arg(expr_cell * e) { return to_app(e)->get_arg(); }
inline expr const & binder_body(expr_cell * e) { return to_binder(e)->get_body(); } inline name const & binder_name(expr_cell * e) { return to_binder(e)->get_name(); }
inline level const & sort_level(expr_cell * e) { return to_sort(e)->get_level(); } inline expr const & binder_domain(expr_cell * e) { return to_binder(e)->get_domain(); }
inline name const & let_name(expr_cell * e) { return to_let(e)->get_name(); } inline expr const & binder_body(expr_cell * e) { return to_binder(e)->get_body(); }
inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); } inline level const & sort_level(expr_cell * e) { return to_sort(e)->get_level(); }
inline expr const & let_type(expr_cell * e) { return to_let(e)->get_type(); } inline name const & let_name(expr_cell * e) { return to_let(e)->get_name(); }
inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); } inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); }
inline name const & mlocal_name(expr_cell * e) { return to_mlocal(e)->get_name(); } inline expr const & let_type(expr_cell * e) { return to_let(e)->get_type(); }
inline expr const & mlocal_type(expr_cell * e) { return to_mlocal(e)->get_type(); } inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); }
inline name const & mlocal_name(expr_cell * e) { return to_mlocal(e)->get_name(); }
inline expr const & mlocal_type(expr_cell * e) { return to_mlocal(e)->get_type(); }
inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); } inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); }
inline bool is_shared(expr const & e) { return get_rc(e) > 1; } inline bool is_shared(expr const & e) { return get_rc(e) > 1; }
inline unsigned var_idx(expr const & e) { return to_var(e)->get_vidx(); } inline unsigned var_idx(expr const & e) { return to_var(e)->get_vidx(); }
inline bool is_var(expr const & e, unsigned i) { return is_var(e) && var_idx(e) == i; } inline bool is_var(expr const & e, unsigned i) { return is_var(e) && var_idx(e) == i; }
inline name const & const_name(expr const & e) { return to_constant(e)->get_name(); } inline name const & const_name(expr const & e) { return to_constant(e)->get_name(); }
inline levels const & const_level_params(expr const & e) { return to_constant(e)->get_level_params(); } inline levels const & const_level_params(expr const & e) { return to_constant(e)->get_level_params(); }
inline macro const & to_macro(expr const & e) { return to_macro(e.raw()); } inline macro_definition const & macro_def(expr const & e) { return to_macro(e)->get_def(); }
inline expr const & app_fn(expr const & e) { return to_app(e)->get_fn(); } inline expr const * macro_args(expr const & e) { return to_macro(e)->get_args(); }
inline expr const & app_arg(expr const & e) { return to_app(e)->get_arg(); } inline expr const & macro_arg(expr const & e, unsigned i) { return to_macro(e)->get_arg(i); }
inline name const & binder_name(expr const & e) { return to_binder(e)->get_name(); } inline unsigned macro_num_args(expr const & e) { return to_macro(e)->get_num_args(); }
inline expr const & binder_domain(expr const & e) { return to_binder(e)->get_domain(); } inline expr const & app_fn(expr const & e) { return to_app(e)->get_fn(); }
inline expr const & binder_body(expr const & e) { return to_binder(e)->get_body(); } inline expr const & app_arg(expr const & e) { return to_app(e)->get_arg(); }
inline level const & sort_level(expr const & e) { return to_sort(e)->get_level(); } inline name const & binder_name(expr const & e) { return to_binder(e)->get_name(); }
inline name const & let_name(expr const & e) { return to_let(e)->get_name(); } inline expr const & binder_domain(expr const & e) { return to_binder(e)->get_domain(); }
inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); } inline expr const & binder_body(expr const & e) { return to_binder(e)->get_body(); }
inline expr const & let_type(expr const & e) { return to_let(e)->get_type(); } inline level const & sort_level(expr const & e) { return to_sort(e)->get_level(); }
inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } inline name const & let_name(expr const & e) { return to_let(e)->get_name(); }
inline name const & mlocal_name(expr const & e) { return to_mlocal(e)->get_name(); } inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); }
inline expr const & mlocal_type(expr const & e) { return to_mlocal(e)->get_type(); } inline expr const & let_type(expr const & e) { return to_let(e)->get_type(); }
inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); }
inline name const & mlocal_name(expr const & e) { return to_mlocal(e)->get_name(); }
inline expr const & mlocal_type(expr const & e) { return to_mlocal(e)->get_type(); }
inline bool is_constant(expr const & e, name const & n) { return is_constant(e) && const_name(e) == n; } inline bool is_constant(expr const & e, name const & n) { return is_constant(e) && const_name(e) == n; }
inline bool has_metavar(expr const & e) { return e.has_metavar(); } inline bool has_metavar(expr const & e) { return e.has_metavar(); }
@ -522,6 +538,7 @@ expr update_let(expr const & e, expr const & new_type, expr const & new_val, exp
expr update_mlocal(expr const & e, expr const & new_type); expr update_mlocal(expr const & e, expr const & new_type);
expr update_sort(expr const & e, level const & new_level); expr update_sort(expr const & e, level const & new_level);
expr update_constant(expr const & e, levels const & new_levels); expr update_constant(expr const & e, levels const & new_levels);
expr update_macro(expr const & e, unsigned num, expr const * args);
// ======================================= // =======================================
// ======================================= // =======================================

View file

@ -43,7 +43,13 @@ bool expr_eq_fn::apply(expr const & a, expr const & b) {
case expr_kind::Sort: case expr_kind::Sort:
return sort_level(a) == sort_level(b); return sort_level(a) == sort_level(b);
case expr_kind::Macro: case expr_kind::Macro:
return to_macro(a) == to_macro(b); if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b))
return false;
for (unsigned i = 0; i < macro_num_args(a); i++) {
if (!apply(macro_arg(a, i), macro_arg(b, i)))
return false;
}
return true;
case expr_kind::Let: case expr_kind::Let:
return return
apply(let_type(a), let_type(b)) && apply(let_type(a), let_type(b)) &&

View file

@ -22,7 +22,7 @@ void for_each_fn::apply(expr const & e, unsigned offset) {
switch (e.kind()) { switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Var: case expr_kind::Constant: case expr_kind::Var:
case expr_kind::Sort: case expr_kind::Macro: case expr_kind::Sort:
m_f(e, offset); m_f(e, offset);
goto begin_loop; goto begin_loop;
default: default:
@ -43,11 +43,19 @@ void for_each_fn::apply(expr const & e, unsigned offset) {
switch (e.kind()) { switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Var: case expr_kind::Constant: case expr_kind::Var:
case expr_kind::Sort: case expr_kind::Macro: case expr_kind::Sort:
goto begin_loop; goto begin_loop;
case expr_kind::Meta: case expr_kind::Local: case expr_kind::Meta: case expr_kind::Local:
todo.emplace_back(mlocal_type(e), offset); todo.emplace_back(mlocal_type(e), offset);
goto begin_loop; goto begin_loop;
case expr_kind::Macro: {
unsigned i = macro_num_args(e);
while (i > 0) {
--i;
todo.emplace_back(macro_arg(e, i), offset);
}
goto begin_loop;
}
case expr_kind::App: case expr_kind::App:
todo.emplace_back(app_arg(e), offset); todo.emplace_back(app_arg(e), offset);
todo.emplace_back(app_fn(e), offset); todo.emplace_back(app_fn(e), offset);

View file

@ -29,8 +29,13 @@ struct print_expr_fn {
} }
} }
void print_macro(expr const & a) { void print_macro(expr const & a, context const & c) {
to_macro(a).display(out()); if (macro_num_args(a) > 0) out() << "(";
macro_def(a).display(out());
for (unsigned i = 0; i < macro_num_args(a); i++) {
out() << " "; print(macro_arg(a, i), c);
}
if (macro_num_args(a) > 0) out() << ")";
} }
void print_sort(expr const & a) { void print_sort(expr const & a) {
@ -109,7 +114,7 @@ struct print_expr_fn {
print_sort(a); print_sort(a);
break; break;
case expr_kind::Macro: case expr_kind::Macro:
print_macro(a); print_macro(a, c);
break; break;
} }
} }

View file

@ -57,7 +57,7 @@ protected:
bool result = false; bool result = false;
switch (e.kind()) { switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Macro: case expr_kind::Constant: case expr_kind::Sort:
case expr_kind::Var: case expr_kind::Var:
lean_unreachable(); // LCOV_EXCL_LINE lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Meta: case expr_kind::Local: case expr_kind::Meta: case expr_kind::Local:
@ -72,6 +72,14 @@ protected:
case expr_kind::Let: case expr_kind::Let:
result = apply(let_type(e), offset) || apply(let_value(e), offset) || apply(let_body(e), offset + 1); result = apply(let_type(e), offset) || apply(let_value(e), offset) || apply(let_body(e), offset + 1);
break; break;
case expr_kind::Macro:
for (unsigned i = 0; i < macro_num_args(e); i++) {
if (apply(macro_arg(e, i), offset)) {
result = true;
break;
}
}
break;
} }
if (!result && shared) { if (!result && shared) {

View file

@ -29,7 +29,7 @@ struct max_sharing_fn::imp {
expr res; expr res;
switch (a.kind()) { switch (a.kind()) {
case expr_kind::Constant: case expr_kind::Var: case expr_kind::Constant: case expr_kind::Var:
case expr_kind::Sort: case expr_kind::Macro: case expr_kind::Sort:
res = a; res = a;
break; break;
case expr_kind::App: case expr_kind::App:
@ -44,7 +44,13 @@ struct max_sharing_fn::imp {
case expr_kind::Meta: case expr_kind::Local: case expr_kind::Meta: case expr_kind::Local:
res = update_mlocal(a, apply(mlocal_type(a))); res = update_mlocal(a, apply(mlocal_type(a)));
break; break;
} case expr_kind::Macro: {
buffer<expr> new_args;
for (unsigned i = 0; i < macro_num_args(a); i++)
new_args.push_back(macro_arg(a, i));
res = update_macro(a, new_args.size(), new_args.data());
break;
}}
m_cache.insert(res); m_cache.insert(res);
return res; return res;
} }

View file

@ -77,7 +77,7 @@ expr replace_fn::operator()(expr const & e) {
unsigned offset = f.m_offset; unsigned offset = f.m_offset;
switch (e.kind()) { switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Constant: case expr_kind::Sort:
case expr_kind::Macro: case expr_kind::Var: case expr_kind::Var:
lean_unreachable(); // LCOV_EXCL_LINE lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Meta: case expr_kind::Local: case expr_kind::Meta: case expr_kind::Local:
if (check_index(f, 0) && !visit(mlocal_type(e), offset)) if (check_index(f, 0) && !visit(mlocal_type(e), offset))
@ -111,6 +111,14 @@ expr replace_fn::operator()(expr const & e) {
r = update_let(e, rs(-3), rs(-2), rs(-1)); r = update_let(e, rs(-3), rs(-2), rs(-1));
pop_rs(3); pop_rs(3);
break; break;
case expr_kind::Macro:
while (f.m_index < macro_num_args(e)) {
if (!visit(macro_arg(e, f.m_index), offset))
goto begin_loop;
}
r = update_macro(e, macro_num_args(e), &rs(-macro_num_args(e)));
pop_rs(macro_num_args(e));
break;
} }
save_result(e, r, offset, f.m_shared); save_result(e, r, offset, f.m_shared);
m_fs.pop_back(); m_fs.pop_back();

View file

@ -11,7 +11,6 @@ Author: Leonardo de Moura
namespace lean { namespace lean {
expr replace_visitor::visit_sort(expr const & e, context const &) { lean_assert(is_sort(e)); return e; } expr replace_visitor::visit_sort(expr const & e, context const &) { lean_assert(is_sort(e)); return e; }
expr replace_visitor::visit_macro(expr const & e, context const &) { lean_assert(is_macro(e)); return e; }
expr replace_visitor::visit_var(expr const & e, context const &) { lean_assert(is_var(e)); return e; } expr replace_visitor::visit_var(expr const & e, context const &) { lean_assert(is_var(e)); return e; }
expr replace_visitor::visit_constant(expr const & e, context const &) { lean_assert(is_constant(e)); return e; } expr replace_visitor::visit_constant(expr const & e, context const &) { lean_assert(is_constant(e)); return e; }
expr replace_visitor::visit_mlocal(expr const & e, context const & ctx) { expr replace_visitor::visit_mlocal(expr const & e, context const & ctx) {
@ -41,6 +40,13 @@ expr replace_visitor::visit_let(expr const & e, context const & ctx) {
expr new_b = visit(let_body(e), extend(ctx, let_name(e), new_t)); expr new_b = visit(let_body(e), extend(ctx, let_name(e), new_t));
return update_let(e, new_t, new_v, new_b); return update_let(e, new_t, new_v, new_b);
} }
expr replace_visitor::visit_macro(expr const & e, context const & ctx) {
lean_assert(is_macro(e));
buffer<expr> new_args;
for (unsigned i = 0; i < macro_num_args(e); i++)
new_args.push_back(visit(macro_arg(e, i), ctx));
return update_macro(e, new_args.size(), new_args.data());
}
expr replace_visitor::save_result(expr const & e, expr && r, bool shared) { expr replace_visitor::save_result(expr const & e, expr && r, bool shared) {
if (shared) if (shared)
m_cache.insert(std::make_pair(e, r)); m_cache.insert(std::make_pair(e, r));

View file

@ -56,9 +56,9 @@ struct type_checker::imp {
expr instantiate(expr const & e, unsigned n, expr const * s) { return max_sharing(lean::instantiate(e, n, s)); } expr instantiate(expr const & e, unsigned n, expr const * s) { return max_sharing(lean::instantiate(e, n, s)); }
expr instantiate(expr const & e, expr const & s) { return max_sharing(lean::instantiate(e, s)); } expr instantiate(expr const & e, expr const & s) { return max_sharing(lean::instantiate(e, s)); }
expr mk_rev_app(expr const & f, unsigned num, expr const * args) { return max_sharing(lean::mk_rev_app(f, num, args)); } expr mk_rev_app(expr const & f, unsigned num, expr const * args) { return max_sharing(lean::mk_rev_app(f, num, args)); }
optional<expr> expand_macro(expr const & m, unsigned num, expr const * args) { optional<expr> expand_macro(expr const & m) {
lean_assert(is_macro(m)); lean_assert(is_macro(m));
if (auto new_m = to_macro(m).expand(num, args)) if (auto new_m = macro_def(m).expand(macro_num_args(m), macro_args(m)))
return some_expr(max_sharing(*new_m)); return some_expr(max_sharing(*new_m));
else else
return none_expr(); return none_expr();
@ -134,7 +134,7 @@ struct type_checker::imp {
case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Constant: case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Constant:
lean_unreachable(); // LCOV_EXCL_LINE lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Macro: case expr_kind::Macro:
if (auto m = expand_macro(e, 0, 0)) if (auto m = expand_macro(e))
r = whnf_core(*m); r = whnf_core(*m);
else else
r = e; r = e;
@ -160,12 +160,6 @@ struct type_checker::imp {
lean_assert(m <= num_args); lean_assert(m <= num_args);
r = whnf_core(mk_rev_app(instantiate(binder_body(f), m, args.data() + (num_args - m)), num_args - m, args.data())); r = whnf_core(mk_rev_app(instantiate(binder_body(f), m, args.data() + (num_args - m)), num_args - m, args.data()));
break; break;
} else if (is_macro(f)) {
auto m = expand_macro(f, args.size(), args.data());
if (m) {
r = whnf_core(*m);
break;
}
} }
r = is_eqp(f, *it) ? e : mk_rev_app(f, args.size(), args.data()); r = is_eqp(f, *it) ? e : mk_rev_app(f, args.size(), args.data());
break; break;
@ -664,10 +658,13 @@ struct type_checker::imp {
r = instantiate_params(d.get_type(), ps, ls); r = instantiate_params(d.get_type(), ps, ls);
break; break;
} }
case expr_kind::Macro: case expr_kind::Macro: {
r = to_macro(e).get_type(0, 0, 0); buffer<expr> arg_types;
if (!infer_only && to_macro(e).trust_level() <= m_env.trust_lvl()) { for (unsigned i = 0; i < macro_num_args(e); i++)
optional<expr> m = to_macro(e).expand(0, 0); arg_types.push_back(infer_type_core(macro_arg(e, i), infer_only));
r = macro_def(e).get_type(macro_num_args(e), macro_args(e), arg_types.data());
if (!infer_only && macro_def(e).trust_level() <= m_env.trust_lvl()) {
optional<expr> m = expand_macro(e);
if (!m) if (!m)
throw_kernel_exception(m_env, "failed to expand macro", some_expr(e)); throw_kernel_exception(m_env, "failed to expand macro", some_expr(e));
expr t = infer_type_core(*m, infer_only); expr t = infer_type_core(*m, infer_only);
@ -676,6 +673,7 @@ struct type_checker::imp {
throw_kernel_exception(m_env, g_macro_error_msg, some_expr(e)); throw_kernel_exception(m_env, g_macro_error_msg, some_expr(e));
} }
break; break;
}
case expr_kind::Lambda: { case expr_kind::Lambda: {
if (!infer_only) { if (!infer_only) {
expr t = infer_type_core(binder_domain(e), infer_only); expr t = infer_type_core(binder_domain(e), infer_only);

View file

@ -92,7 +92,15 @@ bool is_lt(expr const & a, expr const & b, bool use_hash) {
else else
return is_lt(mlocal_type(a), mlocal_type(b), use_hash); return is_lt(mlocal_type(a), mlocal_type(b), use_hash);
case expr_kind::Macro: case expr_kind::Macro:
return to_macro(a) < to_macro(b); if (macro_def(a) != macro_def(b))
return macro_def(a) < macro_def(b);
if (macro_num_args(a) != macro_num_args(b))
return macro_num_args(a) < macro_num_args(b);
for (unsigned i = 0; i < macro_num_args(a); i++) {
if (macro_arg(a, i) != macro_arg(b, i))
return is_lt(macro_arg(a, i), macro_arg(b, i), use_hash);
}
return false;
} }
lean_unreachable(); // LCOV_EXCL_LINE lean_unreachable(); // LCOV_EXCL_LINE
} }