refactor(kernel): store macro arguments in the macro_expr

Before this commit, we "stored" macro arguments using applications.
This representation had some issues. Suppose we use [m a] to denote a macro
application. In the old representation, ([m a] b) and [m a b] would have
the same representation. Another problem is that some procedures (e.g., type inference)
would not have a clean implementation.

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-04-25 15:02:52 -07:00
parent 9d5ed2bae1
commit 4842ae4fc7
11 changed files with 240 additions and 110 deletions

View file

@ -179,21 +179,21 @@ void expr_let::dealloc(buffer<expr_cell*> & todelete) {
}
expr_let::~expr_let() {}
// Macro attachment
int macro::push_lua(lua_State *) const { return 0; } // NOLINT
void macro::display(std::ostream & out) const { out << get_name(); }
bool macro::operator==(macro const & other) const { return typeid(*this) == typeid(other); }
bool macro::operator<(macro const & other) const {
// Macro definition
int macro_definition::push_lua(lua_State *) const { return 0; } // NOLINT
bool macro_definition::operator==(macro_definition const & other) const { return typeid(*this) == typeid(other); }
bool macro_definition::operator<(macro_definition const & other) const {
if (get_name() == other.get_name())
return lt(other);
else
return get_name() < other.get_name();
}
format macro::pp(formatter const &, options const &) const { return format(get_name()); }
bool macro::is_atomic_pp(bool, bool) const { return true; }
unsigned macro::hash() const { return get_name().hash(); }
format macro_definition::pp(formatter const &, options const &) const { return format(get_name()); }
void macro_definition::display(std::ostream & out) const { out << get_name(); }
bool macro_definition::is_atomic_pp(bool, bool) const { return true; }
unsigned macro_definition::hash() const { return get_name().hash(); }
typedef std::unordered_map<std::string, macro::reader> macro_readers;
typedef std::unordered_map<std::string, macro_definition::reader> macro_readers;
static std::unique_ptr<macro_readers> g_macro_readers;
macro_readers & get_macro_readers() {
if (!g_macro_readers)
@ -201,26 +201,61 @@ macro_readers & get_macro_readers() {
return *(g_macro_readers.get());
}
void macro::register_deserializer(std::string const & k, macro::reader rd) {
void macro_definition::register_deserializer(std::string const & k, macro_definition::reader rd) {
macro_readers & readers = get_macro_readers();
lean_assert(readers.find(k) == readers.end());
readers[k] = rd;
}
static expr read_macro(deserializer & d) {
static expr read_macro_definition(deserializer & d, unsigned num, expr const * args) {
auto k = d.read_string();
macro_readers & readers = get_macro_readers();
auto it = readers.find(k);
lean_assert(it != readers.end());
return it->second(d);
return it->second(d, num, args);
}
expr_macro::expr_macro(macro * m):
expr_cell(expr_kind::Macro, m->hash(), false, false, false),
m_macro(m) {
m_macro->inc_ref();
static unsigned max_depth(unsigned num, expr const * args) {
unsigned r = 0;
for (unsigned i = 0; i < num; i++) {
unsigned d = get_depth(args[i]);
if (d > r)
r = d;
}
return r;
}
static unsigned get_free_var_range(unsigned num, expr const * args) {
unsigned r = 0;
for (unsigned i = 0; i < num; i++) {
unsigned d = get_free_var_range(args[i]);
if (d > r)
r = d;
}
return r;
}
expr_macro::expr_macro(macro_definition * m, unsigned num, expr const * args):
expr_composite(expr_kind::Macro,
lean::hash(num, [&](unsigned i) { return args[i].hash(); }, m->hash()),
std::any_of(args, args+num, [](expr const & e) { return e.has_metavar(); }),
std::any_of(args, args+num, [](expr const & e) { return e.has_local(); }),
std::any_of(args, args+num, [](expr const & e) { return e.has_param_univ(); }),
max_depth(num, args) + 1,
get_free_var_range(num, args)),
m_definition(m),
m_num_args(num) {
m_definition->inc_ref();
m_args = new expr[num];
for (unsigned i = 0; i < m_num_args; i++)
m_args[i] = args[i];
}
void expr_macro::dealloc(buffer<expr_cell*> & todelete) {
m_definition->dec_ref();
for (unsigned i = 0; i < m_num_args; i++) dec_ref(m_args[i], todelete);
delete(this);
}
expr_macro::~expr_macro() {
m_macro->dec_ref();
delete[] m_args;
}
void expr_cell::dealloc() {
@ -233,7 +268,7 @@ void expr_cell::dealloc() {
lean_assert(it->get_rc() == 0);
switch (it->kind()) {
case expr_kind::Var: delete static_cast<expr_var*>(it); break;
case expr_kind::Macro: delete static_cast<expr_macro*>(it); break;
case expr_kind::Macro: static_cast<expr_macro*>(it)->dealloc(todo); break;
case expr_kind::Meta:
case expr_kind::Local: static_cast<expr_mlocal*>(it)->dealloc(todo); break;
case expr_kind::Constant: delete static_cast<expr_const*>(it); break;
@ -290,9 +325,9 @@ expr mk_Type() { return Type; }
unsigned get_depth(expr const & e) {
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Sort:
case expr_kind::Meta: case expr_kind::Local: case expr_kind::Macro:
case expr_kind::Meta: case expr_kind::Local:
return 1;
case expr_kind::Lambda: case expr_kind::Pi:
case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Macro:
case expr_kind::App: case expr_kind::Let:
return static_cast<expr_composite*>(e.raw())->m_depth;
}
@ -303,12 +338,13 @@ unsigned get_free_var_range(expr const & e) {
switch (e.kind()) {
case expr_kind::Var:
return var_idx(e) + 1;
case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Macro:
case expr_kind::Constant: case expr_kind::Sort:
return 0;
case expr_kind::Meta: case expr_kind::Local:
return get_free_var_range(mlocal_type(e));
case expr_kind::Lambda: case expr_kind::Pi:
case expr_kind::App: case expr_kind::Let:
case expr_kind::Macro:
return static_cast<expr_composite*>(e.raw())->m_free_var_range;
}
lean_unreachable(); // LCOV_EXCL_LINE
@ -370,11 +406,26 @@ expr update_constant(expr const & e, levels const & new_levels) {
return e;
}
expr update_macro(expr const & e, unsigned num, expr const * args) {
if (num == macro_num_args(e)) {
unsigned i = 0;
for (i = 0; i < num; i++) {
if (!is_eqp(macro_arg(e, i), args[i]))
break;
}
if (i == num)
return e;
}
return mk_macro(to_macro(e)->m_definition, num, args);
}
bool is_atomic(expr const & e) {
switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Sort:
case expr_kind::Macro: case expr_kind::Var:
case expr_kind::Var:
return true;
case expr_kind::Macro:
return to_macro(e)->get_num_args() == 0;
case expr_kind::App: case expr_kind::Let:
case expr_kind::Meta: case expr_kind::Local:
case expr_kind::Lambda: case expr_kind::Pi:
@ -399,7 +450,7 @@ expr copy(expr const & a) {
case expr_kind::Var: return mk_var(var_idx(a));
case expr_kind::Constant: return mk_constant(const_name(a), const_level_params(a));
case expr_kind::Sort: return mk_sort(sort_level(a));
case expr_kind::Macro: return mk_macro(static_cast<expr_macro*>(a.raw())->m_macro);
case expr_kind::Macro: return mk_macro(to_macro(a)->m_definition, macro_num_args(a), macro_args(a));
case expr_kind::App: return mk_app(app_fn(a), app_arg(a));
case expr_kind::Lambda: return mk_lambda(binder_name(a), binder_domain(a), binder_body(a));
case expr_kind::Pi: return mk_pi(binder_name(a), binder_domain(a), binder_body(a));
@ -429,7 +480,11 @@ class expr_serializer : public object_serializer<expr, expr_hash_alloc, expr_eqp
s << sort_level(a);
break;
case expr_kind::Macro:
to_macro(a).write(s);
s << macro_num_args(a);
for (unsigned i = 0; i < macro_num_args(a); i++) {
write_core(macro_arg(a, i));
}
macro_def(a).write(s);
break;
case expr_kind::App:
write_core(app_fn(a)); write_core(app_arg(a));
@ -475,9 +530,14 @@ public:
}
case expr_kind::Sort:
return mk_sort(read_level(d));
break;
case expr_kind::Macro:
return read_macro(d);
case expr_kind::Macro: {
unsigned n = d.read_unsigned();
buffer<expr> args;
for (unsigned i = 0; i < n; i++) {
args.push_back(read());
}
return read_macro_definition(d, args.size(), args.data());
}
case expr_kind::App: {
expr f = read();
return mk_app(f, read());

View file

@ -70,7 +70,7 @@ public:
bool has_param_univ() const { return m_has_param_univ; }
};
class macro;
class macro_definition;
/**
\brief Exprs for encoding formulas/expressions, types and proofs.
@ -119,7 +119,7 @@ public:
friend expr mk_proj(bool fst, expr const & p);
friend expr mk_binder(expr_kind k, name const & n, expr const & t, expr const & e);
friend expr mk_let(name const & n, expr const & t, expr const & v, expr const & e);
friend expr mk_macro(macro * m);
friend expr mk_macro(macro_definition * m, unsigned num, expr const * args);
friend bool is_eqp(expr const & a, expr const & b) { return a.m_ptr == b.m_ptr; }
// Overloaded operator() can be used to create applications
@ -245,8 +245,8 @@ public:
class formatter;
/** \brief Base class for macro attachments */
class macro {
/** \brief Base class for macro definition attachments */
class macro_definition {
void dealloc() { delete this; }
MK_LEAN_RC();
protected:
@ -255,39 +255,48 @@ protected:
attachments. It is invoked by operator<, and it is only invoked when
<tt>get_name() == other.get_name()</tt>
*/
virtual bool lt(macro const &) const { return false; }
virtual bool lt(macro_definition const &) const { return false; }
public:
macro():m_rc(0) {}
virtual ~macro() {}
macro_definition():m_rc(0) {}
virtual ~macro_definition() {}
virtual name get_name() const = 0;
virtual expr get_type(unsigned num_args, expr const * args, expr const * arg_types) const = 0;
virtual optional<expr> expand1(unsigned num_args, expr const * args) const = 0;
virtual optional<expr> expand(unsigned num_args, expr const * args) const = 0;
virtual expr get_type(unsigned num, expr const * args, expr const * arg_types) const = 0;
virtual optional<expr> expand1(unsigned num, expr const * args) const = 0;
virtual optional<expr> expand(unsigned num, expr const * args) const = 0;
virtual unsigned trust_level() const = 0;
virtual int push_lua(lua_State * L) const;
virtual bool operator==(macro const & other) const;
bool operator<(macro const & other) const;
virtual bool operator==(macro_definition const & other) const;
bool operator!=(macro_definition const & other) const { return !operator==(other); }
bool operator<(macro_definition const & other) const;
virtual void display(std::ostream & out) const;
virtual format pp(formatter const & fmt, options const & opts) const;
virtual bool is_atomic_pp(bool unicode, bool coercion) const;
virtual unsigned hash() const;
virtual void write(serializer & s) const = 0;
typedef std::function<expr(deserializer&)> reader;
typedef std::function<expr(deserializer&, unsigned, expr const *)> reader;
static void register_deserializer(std::string const & k, reader rd);
struct register_deserializer_fn {
register_deserializer_fn(std::string const & k, macro::reader rd) { macro::register_deserializer(k, rd); }
register_deserializer_fn(std::string const & k, macro_definition::reader rd) { macro_definition::register_deserializer(k, rd); }
};
};
/** \brief Macro attachments */
class expr_macro : public expr_cell {
macro * m_macro;
class expr_macro : public expr_composite {
macro_definition * m_definition;
unsigned m_num_args;
expr * m_args;
friend class expr_cell;
friend expr copy(expr const & a);
friend expr update_macro(expr const & e, unsigned num, expr const * args);
void dealloc(buffer<expr_cell*> & todelete);
public:
expr_macro(macro * v);
expr_macro(macro_definition * v, unsigned num, expr const * args);
~expr_macro();
macro const & get_macro() const { return *m_macro; }
macro_definition const & get_def() const { return *m_definition; }
expr const * get_args() const { return m_args; }
expr const & get_arg(unsigned idx) const { lean_assert(idx < m_num_args); return m_args[idx]; }
unsigned get_num_args() const { return m_num_args; }
};
// =======================================
@ -331,7 +340,7 @@ inline expr Var(unsigned idx) { return mk_var(idx); }
inline expr mk_constant(name const & n, levels const & ls) { return expr(new expr_const(n, ls)); }
inline expr mk_constant(name const & n) { return mk_constant(n, levels()); }
inline expr Const(name const & n) { return mk_constant(n); }
inline expr mk_macro(macro * m) { return expr(new expr_macro(m)); }
inline expr mk_macro(macro_definition * m, unsigned num = 0, expr const * args = nullptr) { return expr(new expr_macro(m, num, args)); }
inline expr mk_mlocal(bool is_meta, name const & n, expr const & t) { return expr(new expr_mlocal(is_meta, n, t)); }
inline expr mk_metavar(name const & n, expr const & t) { return mk_mlocal(true, n, t); }
inline expr mk_local(name const & n, expr const & t) { return mk_mlocal(false, n, t); }
@ -399,6 +408,7 @@ inline expr_sort * to_sort(expr_cell * e) { lean_assert(is_sort(e))
inline expr_mlocal * to_mlocal(expr_cell * e) { lean_assert(is_mlocal(e)); return static_cast<expr_mlocal*>(e); }
inline expr_mlocal * to_local(expr_cell * e) { lean_assert(is_local(e)); return static_cast<expr_mlocal*>(e); }
inline expr_mlocal * to_metavar(expr_cell * e) { lean_assert(is_metavar(e)); return static_cast<expr_mlocal*>(e); }
inline expr_macro * to_macro(expr_cell * e) { lean_assert(is_macro(e)); return static_cast<expr_macro*>(e); }
inline expr_var * to_var(expr const & e) { return to_var(e.raw()); }
inline expr_const * to_constant(expr const & e) { return to_constant(e.raw()); }
@ -409,51 +419,57 @@ inline expr_sort * to_sort(expr const & e) { return to_sort(e.raw(
inline expr_mlocal * to_mlocal(expr const & e) { return to_mlocal(e.raw()); }
inline expr_mlocal * to_metavar(expr const & e) { return to_metavar(e.raw()); }
inline expr_mlocal * to_local(expr const & e) { return to_local(e.raw()); }
inline expr_macro * to_macro(expr const & e) { return to_macro(e.raw()); }
// =======================================
// =======================================
// Accessors
inline unsigned get_rc(expr_cell * e) { return e->get_rc(); }
inline bool is_shared(expr_cell * e) { return get_rc(e) > 1; }
inline unsigned var_idx(expr_cell * e) { return to_var(e)->get_vidx(); }
inline bool is_var(expr_cell * e, unsigned i) { return is_var(e) && var_idx(e) == i; }
inline name const & const_name(expr_cell * e) { return to_constant(e)->get_name(); }
inline levels const & const_level_params(expr_cell * e) { return to_constant(e)->get_level_params(); }
inline macro const & to_macro(expr_cell * e) {
lean_assert(is_macro(e)); return static_cast<expr_macro*>(e)->get_macro(); }
inline expr const & app_fn(expr_cell * e) { return to_app(e)->get_fn(); }
inline expr const & app_arg(expr_cell * e) { return to_app(e)->get_arg(); }
inline name const & binder_name(expr_cell * e) { return to_binder(e)->get_name(); }
inline expr const & binder_domain(expr_cell * e) { return to_binder(e)->get_domain(); }
inline expr const & binder_body(expr_cell * e) { return to_binder(e)->get_body(); }
inline level const & sort_level(expr_cell * e) { return to_sort(e)->get_level(); }
inline name const & let_name(expr_cell * e) { return to_let(e)->get_name(); }
inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); }
inline expr const & let_type(expr_cell * e) { return to_let(e)->get_type(); }
inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); }
inline name const & mlocal_name(expr_cell * e) { return to_mlocal(e)->get_name(); }
inline expr const & mlocal_type(expr_cell * e) { return to_mlocal(e)->get_type(); }
inline unsigned get_rc(expr_cell * e) { return e->get_rc(); }
inline bool is_shared(expr_cell * e) { return get_rc(e) > 1; }
inline unsigned var_idx(expr_cell * e) { return to_var(e)->get_vidx(); }
inline bool is_var(expr_cell * e, unsigned i) { return is_var(e) && var_idx(e) == i; }
inline name const & const_name(expr_cell * e) { return to_constant(e)->get_name(); }
inline levels const & const_level_params(expr_cell * e) { return to_constant(e)->get_level_params(); }
inline macro_definition const & macro_def(expr_cell * e) { return to_macro(e)->get_def(); }
inline expr const * macro_args(expr_cell * e) { return to_macro(e)->get_args(); }
inline expr const & macro_arg(expr_cell * e, unsigned i) { return to_macro(e)->get_arg(i); }
inline unsigned macro_num_args(expr_cell * e) { return to_macro(e)->get_num_args(); }
inline expr const & app_fn(expr_cell * e) { return to_app(e)->get_fn(); }
inline expr const & app_arg(expr_cell * e) { return to_app(e)->get_arg(); }
inline name const & binder_name(expr_cell * e) { return to_binder(e)->get_name(); }
inline expr const & binder_domain(expr_cell * e) { return to_binder(e)->get_domain(); }
inline expr const & binder_body(expr_cell * e) { return to_binder(e)->get_body(); }
inline level const & sort_level(expr_cell * e) { return to_sort(e)->get_level(); }
inline name const & let_name(expr_cell * e) { return to_let(e)->get_name(); }
inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); }
inline expr const & let_type(expr_cell * e) { return to_let(e)->get_type(); }
inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); }
inline name const & mlocal_name(expr_cell * e) { return to_mlocal(e)->get_name(); }
inline expr const & mlocal_type(expr_cell * e) { return to_mlocal(e)->get_type(); }
inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); }
inline bool is_shared(expr const & e) { return get_rc(e) > 1; }
inline unsigned var_idx(expr const & e) { return to_var(e)->get_vidx(); }
inline bool is_var(expr const & e, unsigned i) { return is_var(e) && var_idx(e) == i; }
inline name const & const_name(expr const & e) { return to_constant(e)->get_name(); }
inline levels const & const_level_params(expr const & e) { return to_constant(e)->get_level_params(); }
inline macro const & to_macro(expr const & e) { return to_macro(e.raw()); }
inline expr const & app_fn(expr const & e) { return to_app(e)->get_fn(); }
inline expr const & app_arg(expr const & e) { return to_app(e)->get_arg(); }
inline name const & binder_name(expr const & e) { return to_binder(e)->get_name(); }
inline expr const & binder_domain(expr const & e) { return to_binder(e)->get_domain(); }
inline expr const & binder_body(expr const & e) { return to_binder(e)->get_body(); }
inline level const & sort_level(expr const & e) { return to_sort(e)->get_level(); }
inline name const & let_name(expr const & e) { return to_let(e)->get_name(); }
inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); }
inline expr const & let_type(expr const & e) { return to_let(e)->get_type(); }
inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); }
inline name const & mlocal_name(expr const & e) { return to_mlocal(e)->get_name(); }
inline expr const & mlocal_type(expr const & e) { return to_mlocal(e)->get_type(); }
inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); }
inline bool is_shared(expr const & e) { return get_rc(e) > 1; }
inline unsigned var_idx(expr const & e) { return to_var(e)->get_vidx(); }
inline bool is_var(expr const & e, unsigned i) { return is_var(e) && var_idx(e) == i; }
inline name const & const_name(expr const & e) { return to_constant(e)->get_name(); }
inline levels const & const_level_params(expr const & e) { return to_constant(e)->get_level_params(); }
inline macro_definition const & macro_def(expr const & e) { return to_macro(e)->get_def(); }
inline expr const * macro_args(expr const & e) { return to_macro(e)->get_args(); }
inline expr const & macro_arg(expr const & e, unsigned i) { return to_macro(e)->get_arg(i); }
inline unsigned macro_num_args(expr const & e) { return to_macro(e)->get_num_args(); }
inline expr const & app_fn(expr const & e) { return to_app(e)->get_fn(); }
inline expr const & app_arg(expr const & e) { return to_app(e)->get_arg(); }
inline name const & binder_name(expr const & e) { return to_binder(e)->get_name(); }
inline expr const & binder_domain(expr const & e) { return to_binder(e)->get_domain(); }
inline expr const & binder_body(expr const & e) { return to_binder(e)->get_body(); }
inline level const & sort_level(expr const & e) { return to_sort(e)->get_level(); }
inline name const & let_name(expr const & e) { return to_let(e)->get_name(); }
inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); }
inline expr const & let_type(expr const & e) { return to_let(e)->get_type(); }
inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); }
inline name const & mlocal_name(expr const & e) { return to_mlocal(e)->get_name(); }
inline expr const & mlocal_type(expr const & e) { return to_mlocal(e)->get_type(); }
inline bool is_constant(expr const & e, name const & n) { return is_constant(e) && const_name(e) == n; }
inline bool has_metavar(expr const & e) { return e.has_metavar(); }
@ -522,6 +538,7 @@ expr update_let(expr const & e, expr const & new_type, expr const & new_val, exp
expr update_mlocal(expr const & e, expr const & new_type);
expr update_sort(expr const & e, level const & new_level);
expr update_constant(expr const & e, levels const & new_levels);
expr update_macro(expr const & e, unsigned num, expr const * args);
// =======================================
// =======================================

View file

@ -43,7 +43,13 @@ bool expr_eq_fn::apply(expr const & a, expr const & b) {
case expr_kind::Sort:
return sort_level(a) == sort_level(b);
case expr_kind::Macro:
return to_macro(a) == to_macro(b);
if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b))
return false;
for (unsigned i = 0; i < macro_num_args(a); i++) {
if (!apply(macro_arg(a, i), macro_arg(b, i)))
return false;
}
return true;
case expr_kind::Let:
return
apply(let_type(a), let_type(b)) &&

View file

@ -22,7 +22,7 @@ void for_each_fn::apply(expr const & e, unsigned offset) {
switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Var:
case expr_kind::Sort: case expr_kind::Macro:
case expr_kind::Sort:
m_f(e, offset);
goto begin_loop;
default:
@ -43,11 +43,19 @@ void for_each_fn::apply(expr const & e, unsigned offset) {
switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Var:
case expr_kind::Sort: case expr_kind::Macro:
case expr_kind::Sort:
goto begin_loop;
case expr_kind::Meta: case expr_kind::Local:
todo.emplace_back(mlocal_type(e), offset);
goto begin_loop;
case expr_kind::Macro: {
unsigned i = macro_num_args(e);
while (i > 0) {
--i;
todo.emplace_back(macro_arg(e, i), offset);
}
goto begin_loop;
}
case expr_kind::App:
todo.emplace_back(app_arg(e), offset);
todo.emplace_back(app_fn(e), offset);

View file

@ -29,8 +29,13 @@ struct print_expr_fn {
}
}
void print_macro(expr const & a) {
to_macro(a).display(out());
void print_macro(expr const & a, context const & c) {
if (macro_num_args(a) > 0) out() << "(";
macro_def(a).display(out());
for (unsigned i = 0; i < macro_num_args(a); i++) {
out() << " "; print(macro_arg(a, i), c);
}
if (macro_num_args(a) > 0) out() << ")";
}
void print_sort(expr const & a) {
@ -109,7 +114,7 @@ struct print_expr_fn {
print_sort(a);
break;
case expr_kind::Macro:
print_macro(a);
print_macro(a, c);
break;
}
}

View file

@ -57,7 +57,7 @@ protected:
bool result = false;
switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Macro:
case expr_kind::Constant: case expr_kind::Sort:
case expr_kind::Var:
lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Meta: case expr_kind::Local:
@ -72,6 +72,14 @@ protected:
case expr_kind::Let:
result = apply(let_type(e), offset) || apply(let_value(e), offset) || apply(let_body(e), offset + 1);
break;
case expr_kind::Macro:
for (unsigned i = 0; i < macro_num_args(e); i++) {
if (apply(macro_arg(e, i), offset)) {
result = true;
break;
}
}
break;
}
if (!result && shared) {

View file

@ -29,7 +29,7 @@ struct max_sharing_fn::imp {
expr res;
switch (a.kind()) {
case expr_kind::Constant: case expr_kind::Var:
case expr_kind::Sort: case expr_kind::Macro:
case expr_kind::Sort:
res = a;
break;
case expr_kind::App:
@ -44,7 +44,13 @@ struct max_sharing_fn::imp {
case expr_kind::Meta: case expr_kind::Local:
res = update_mlocal(a, apply(mlocal_type(a)));
break;
}
case expr_kind::Macro: {
buffer<expr> new_args;
for (unsigned i = 0; i < macro_num_args(a); i++)
new_args.push_back(macro_arg(a, i));
res = update_macro(a, new_args.size(), new_args.data());
break;
}}
m_cache.insert(res);
return res;
}

View file

@ -77,7 +77,7 @@ expr replace_fn::operator()(expr const & e) {
unsigned offset = f.m_offset;
switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Sort:
case expr_kind::Macro: case expr_kind::Var:
case expr_kind::Var:
lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Meta: case expr_kind::Local:
if (check_index(f, 0) && !visit(mlocal_type(e), offset))
@ -111,6 +111,14 @@ expr replace_fn::operator()(expr const & e) {
r = update_let(e, rs(-3), rs(-2), rs(-1));
pop_rs(3);
break;
case expr_kind::Macro:
while (f.m_index < macro_num_args(e)) {
if (!visit(macro_arg(e, f.m_index), offset))
goto begin_loop;
}
r = update_macro(e, macro_num_args(e), &rs(-macro_num_args(e)));
pop_rs(macro_num_args(e));
break;
}
save_result(e, r, offset, f.m_shared);
m_fs.pop_back();

View file

@ -11,7 +11,6 @@ Author: Leonardo de Moura
namespace lean {
expr replace_visitor::visit_sort(expr const & e, context const &) { lean_assert(is_sort(e)); return e; }
expr replace_visitor::visit_macro(expr const & e, context const &) { lean_assert(is_macro(e)); return e; }
expr replace_visitor::visit_var(expr const & e, context const &) { lean_assert(is_var(e)); return e; }
expr replace_visitor::visit_constant(expr const & e, context const &) { lean_assert(is_constant(e)); return e; }
expr replace_visitor::visit_mlocal(expr const & e, context const & ctx) {
@ -41,6 +40,13 @@ expr replace_visitor::visit_let(expr const & e, context const & ctx) {
expr new_b = visit(let_body(e), extend(ctx, let_name(e), new_t));
return update_let(e, new_t, new_v, new_b);
}
expr replace_visitor::visit_macro(expr const & e, context const & ctx) {
lean_assert(is_macro(e));
buffer<expr> new_args;
for (unsigned i = 0; i < macro_num_args(e); i++)
new_args.push_back(visit(macro_arg(e, i), ctx));
return update_macro(e, new_args.size(), new_args.data());
}
expr replace_visitor::save_result(expr const & e, expr && r, bool shared) {
if (shared)
m_cache.insert(std::make_pair(e, r));

View file

@ -56,9 +56,9 @@ struct type_checker::imp {
expr instantiate(expr const & e, unsigned n, expr const * s) { return max_sharing(lean::instantiate(e, n, s)); }
expr instantiate(expr const & e, expr const & s) { return max_sharing(lean::instantiate(e, s)); }
expr mk_rev_app(expr const & f, unsigned num, expr const * args) { return max_sharing(lean::mk_rev_app(f, num, args)); }
optional<expr> expand_macro(expr const & m, unsigned num, expr const * args) {
optional<expr> expand_macro(expr const & m) {
lean_assert(is_macro(m));
if (auto new_m = to_macro(m).expand(num, args))
if (auto new_m = macro_def(m).expand(macro_num_args(m), macro_args(m)))
return some_expr(max_sharing(*new_m));
else
return none_expr();
@ -134,7 +134,7 @@ struct type_checker::imp {
case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Constant:
lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Macro:
if (auto m = expand_macro(e, 0, 0))
if (auto m = expand_macro(e))
r = whnf_core(*m);
else
r = e;
@ -160,12 +160,6 @@ struct type_checker::imp {
lean_assert(m <= num_args);
r = whnf_core(mk_rev_app(instantiate(binder_body(f), m, args.data() + (num_args - m)), num_args - m, args.data()));
break;
} else if (is_macro(f)) {
auto m = expand_macro(f, args.size(), args.data());
if (m) {
r = whnf_core(*m);
break;
}
}
r = is_eqp(f, *it) ? e : mk_rev_app(f, args.size(), args.data());
break;
@ -664,10 +658,13 @@ struct type_checker::imp {
r = instantiate_params(d.get_type(), ps, ls);
break;
}
case expr_kind::Macro:
r = to_macro(e).get_type(0, 0, 0);
if (!infer_only && to_macro(e).trust_level() <= m_env.trust_lvl()) {
optional<expr> m = to_macro(e).expand(0, 0);
case expr_kind::Macro: {
buffer<expr> arg_types;
for (unsigned i = 0; i < macro_num_args(e); i++)
arg_types.push_back(infer_type_core(macro_arg(e, i), infer_only));
r = macro_def(e).get_type(macro_num_args(e), macro_args(e), arg_types.data());
if (!infer_only && macro_def(e).trust_level() <= m_env.trust_lvl()) {
optional<expr> m = expand_macro(e);
if (!m)
throw_kernel_exception(m_env, "failed to expand macro", some_expr(e));
expr t = infer_type_core(*m, infer_only);
@ -676,6 +673,7 @@ struct type_checker::imp {
throw_kernel_exception(m_env, g_macro_error_msg, some_expr(e));
}
break;
}
case expr_kind::Lambda: {
if (!infer_only) {
expr t = infer_type_core(binder_domain(e), infer_only);

View file

@ -92,7 +92,15 @@ bool is_lt(expr const & a, expr const & b, bool use_hash) {
else
return is_lt(mlocal_type(a), mlocal_type(b), use_hash);
case expr_kind::Macro:
return to_macro(a) < to_macro(b);
if (macro_def(a) != macro_def(b))
return macro_def(a) < macro_def(b);
if (macro_num_args(a) != macro_num_args(b))
return macro_num_args(a) < macro_num_args(b);
for (unsigned i = 0; i < macro_num_args(a); i++) {
if (macro_arg(a, i) != macro_arg(b, i))
return is_lt(macro_arg(a, i), macro_arg(b, i), use_hash);
}
return false;
}
lean_unreachable(); // LCOV_EXCL_LINE
}