Refactor elaborator for supporting overloads

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-09-01 10:24:10 -07:00
parent b2924bba99
commit 598daa40bc
5 changed files with 175 additions and 84 deletions

View file

@ -377,9 +377,8 @@ class parser::imp {
/** /**
\brief Return the function associated with the given operator. \brief Return the function associated with the given operator.
If the operator has been overloaded, it returns an expression If the operator has been overloaded, it returns a choice expression
of the form (overload f_k ... (overload f_2 f_1) ...) of the form <tt>(choice f_1 f_2 ... f_k)</tt> where f_i's are different options.
where f_i's are different options.
After we finish parsing, the procedure #elaborate will After we finish parsing, the procedure #elaborate will
resolve/decide which f_i should be used. resolve/decide which f_i should be used.
*/ */
@ -389,9 +388,15 @@ class parser::imp {
auto it = fs.begin(); auto it = fs.begin();
expr r = *it; expr r = *it;
++it; ++it;
for (; it != fs.end(); ++it) if (it == fs.end()) {
r = mk_app(mk_overload_marker(), *it, r); return r;
return r; } else {
buffer<expr> alternatives;
alternatives.push_back(r);
for (; it != fs.end(); ++it)
alternatives.push_back(*it);
return mk_choice(alternatives.size(), alternatives.data());
}
} }
/** /**

View file

@ -17,17 +17,28 @@ Author: Leonardo de Moura
#include "elaborator_exception.h" #include "elaborator_exception.h"
namespace lean { namespace lean {
static name g_overload_name(name(name(name(0u), "library"), "overload")); static name g_choice_name(name(name(name(0u), "library"), "choice"));
static expr g_overload = mk_constant(g_overload_name); static expr g_choice = mk_constant(g_choice_name);
static format g_assignment_fmt = format(":="); static format g_assignment_fmt = format(":=");
static format g_unification_fmt = format("\u2248"); static format g_unification_fmt = format("\u2248");
bool is_overload_marker(expr const & e) { expr mk_choice(unsigned num_fs, expr const * fs) {
return e == g_overload; lean_assert(num_fs >= 2);
return mk_eq(g_choice, mk_app(num_fs, fs));
} }
expr mk_overload_marker() { bool is_choice(expr const & e) {
return g_overload; return is_eq(e) && eq_lhs(e) == g_choice;
}
unsigned get_num_choices(expr const & e) {
lean_assert(is_choice(e));
return num_args(eq_rhs(e));
}
expr const & get_choice(expr const & e, unsigned i) {
lean_assert(is_choice(e));
return arg(eq_rhs(e), i);
} }
class elaborator::imp { class elaborator::imp {
@ -82,13 +93,22 @@ class elaborator::imp {
volatile bool m_interrupted; volatile bool m_interrupted;
expr mk_metavar(context const & ctx) {
unsigned midx = m_metavars.size();
expr r = ::lean::mk_metavar(midx);
m_metavars.push_back(metavar_info());
m_metavars[midx].m_mvar = r;
m_metavars[midx].m_ctx = ctx;
return r;
}
expr metavar_type(expr const & m) { expr metavar_type(expr const & m) {
lean_assert(is_metavar(m)); lean_assert(is_metavar(m));
unsigned midx = metavar_idx(m); unsigned midx = metavar_idx(m);
if (m_metavars[midx].m_type) { if (m_metavars[midx].m_type) {
return m_metavars[midx].m_type; return m_metavars[midx].m_type;
} else { } else {
expr t = mk_metavar(); expr t = mk_metavar(m_metavars[midx].m_ctx);
m_metavars[midx].m_type = t; m_metavars[midx].m_type = t;
return t; return t;
} }
@ -163,67 +183,139 @@ class elaborator::imp {
} }
} }
expr infer(expr const & e, context const & ctx) { typedef std::pair<expr, expr> expr_pair;
/**
\brief Traverse the expression \c e, and compute
1- A new expression that does not contain choice expressions,
coercions have been added when appropriate, and placeholders
have been replaced with metavariables.
2- The type of \c e.
It also populates m_constraints with a set of constraints that
need to be solved to infer the value of the metavariables.
*/
expr_pair process(expr const & e, context const & ctx) {
check_interrupted(m_interrupted); check_interrupted(m_interrupted);
switch (e.kind()) { switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Constant:
if (is_metavar(e)) { if (is_placeholder(e)) {
unsigned midx = metavar_idx(e); expr m = mk_metavar(ctx);
if (!(m_metavars[midx].m_ctx)) { m_trace[m] = e;
lean_assert(!(m_metavars[midx].m_mvar)); return expr_pair(m, metavar_type(m));
m_metavars[midx].m_mvar = e; } else if (is_metavar(e)) {
m_metavars[midx].m_ctx = ctx; return expr_pair(e, metavar_type(e));
}
return metavar_type(e);
} else { } else {
return m_env.get_object(const_name(e)).get_type(); return expr_pair(e, m_env.get_object(const_name(e)).get_type());
} }
case expr_kind::Var: case expr_kind::Var:
return lookup(ctx, var_idx(e)); return expr_pair(e, lookup(ctx, var_idx(e)));
case expr_kind::Type: case expr_kind::Type:
return mk_type(ty_level(e) + 1); return expr_pair(e, mk_type(ty_level(e) + 1));
case expr_kind::Value: case expr_kind::Value:
return to_value(e).get_type(); return expr_pair(e, to_value(e).get_type());
case expr_kind::App: { case expr_kind::App: {
buffer<expr> args;
buffer<expr> types; buffer<expr> types;
buffer<expr> f_choices;
buffer<expr> f_choice_types;
unsigned num = num_args(e); unsigned num = num_args(e);
for (unsigned i = 0; i < num; i++) { unsigned i = 0;
types.push_back(infer(arg(e,i), ctx)); bool modified = false;
} expr const & f = arg(e, 0);
// TODO: handle overload in args[0] if (is_metavar(f)) {
expr f_t = types[0];
if (!f_t) {
throw invalid_placeholder_exception(*m_owner, ctx, e); throw invalid_placeholder_exception(*m_owner, ctx, e);
} else if (is_choice(f)) {
unsigned num_alts = get_num_choices(f);
for (unsigned j = 0; j < num_alts; j++) {
auto p = process(get_choice(f, j), ctx);
f_choices.push_back(p.first);
f_choice_types.push_back(p.second);
}
args.push_back(expr()); // placeholder
types.push_back(expr()); // placeholder
modified = true;
i++;
} }
for (; i < num; i++) {
expr const & a_i = arg(e, i);
auto p = process(a_i, ctx);
if (!is_eqp(p.first, a_i))
modified = true;
args.push_back(p.first);
types.push_back(p.second);
}
// TODO: choose an f from f_choices
expr f_t = types[0];
for (unsigned i = 1; i < num; i++) { for (unsigned i = 1; i < num; i++) {
f_t = check_pi(f_t, ctx, e, ctx); f_t = check_pi(f_t, ctx, e, ctx);
if (m_add_constraints) if (m_add_constraints)
add_constraint(abst_domain(f_t), types[i], ctx, e, i); add_constraint(abst_domain(f_t), types[i], ctx, e, i);
f_t = instantiate_free_var_mmv(abst_body(f_t), 0, arg(e, i)); f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]);
}
if (modified) {
expr new_e = mk_app(args.size(), args.data());
m_trace[new_e] = e;
return expr_pair(new_e, f_t);
} else {
return expr_pair(e, f_t);
} }
return f_t;
} }
case expr_kind::Eq: { case expr_kind::Eq: {
infer(eq_lhs(e), ctx); auto lhs_p = process(eq_lhs(e), ctx);
infer(eq_rhs(e), ctx); auto rhs_p = process(eq_rhs(e), ctx);
return mk_bool_type(); if (is_eqp(lhs_p.first, eq_lhs(e)) && is_eqp(rhs_p.first, eq_rhs(e))) {
return expr_pair(e, mk_bool_type());
} else {
expr new_e = mk_eq(lhs_p.first, rhs_p.first);
m_trace[new_e] = e;
return expr_pair(new_e, mk_bool_type());
}
} }
case expr_kind::Pi: { case expr_kind::Pi: {
expr dt = infer(abst_domain(e), ctx); auto d_p = process(abst_domain(e), ctx);
expr bt = infer(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); auto b_p = process(abst_body(e), extend(ctx, abst_name(e), d_p.first));
return mk_type(max(check_universe(dt, ctx, e, ctx), check_universe(bt, ctx, e, ctx))); expr t = mk_type(max(check_universe(d_p.second, ctx, e, ctx), check_universe(b_p.second, ctx, e, ctx)));
if (is_eqp(d_p.first, abst_domain(e)) && is_eqp(b_p.first, abst_body(e))) {
return expr_pair(e, t);
} else {
expr new_e = mk_pi(abst_name(e), d_p.first, b_p.first);
m_trace[new_e] = e;
return expr_pair(new_e, t);
}
} }
case expr_kind::Lambda: { case expr_kind::Lambda: {
expr dt = infer(abst_domain(e), ctx); auto d_p = process(abst_domain(e), ctx);
expr bt = infer(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); auto b_p = process(abst_body(e), extend(ctx, abst_name(e), d_p.first));
return mk_pi(abst_name(e), abst_domain(e), bt); expr t = mk_pi(abst_name(e), d_p.first, b_p.second);
if (is_eqp(d_p.first, abst_domain(e)) && is_eqp(b_p.first, abst_body(e))) {
return expr_pair(e, t);
} else {
expr new_e = mk_lambda(abst_name(e), d_p.first, b_p.first);
m_trace[new_e] = e;
return expr_pair(new_e, t);
}
} }
case expr_kind::Let: { case expr_kind::Let: {
expr lt = infer(let_value(e), ctx); auto v_p = process(let_value(e), ctx);
return lower_free_vars_mmv(infer(let_body(e), extend(ctx, let_name(e), lt, let_value(e))), 1, 1); auto b_p = process(let_body(e), extend(ctx, let_name(e), v_p.second, v_p.first));
expr t = lower_free_vars_mmv(b_p.second, 1, 1);
if (is_eqp(v_p.first, let_value(e)) && is_eqp(b_p.first, let_body(e))) {
return expr_pair(e, t);
} else {
expr new_e = mk_let(let_name(e), v_p.first, b_p.first);
m_trace[new_e] = e;
return expr_pair(new_e, t);
}
}} }}
lean_unreachable(); lean_unreachable();
return expr(); return expr_pair(expr(), expr());
}
expr infer(expr const & e, context const & ctx) {
return process(e, ctx).second;
} }
bool is_simple_ho_match(expr const & e1, expr const & e2, context const & ctx) { bool is_simple_ho_match(expr const & e1, expr const & e2, context const & ctx) {
@ -454,7 +546,8 @@ class elaborator::imp {
return replacer(e); return replacer(e);
} }
void solve(unsigned num_meta) { void solve() {
unsigned num_meta = m_metavars.size();
m_add_constraints = false; m_add_constraints = false;
while (true) { while (true) {
solve_core(); solve_core();
@ -493,24 +586,6 @@ class elaborator::imp {
} }
} }
expr mk_metavars(expr const & e) {
// replace placeholders with fresh metavars
auto proc = [&](expr const & n, unsigned offset) -> expr {
if (is_placeholder(n)) {
return mk_metavar();
} else {
return n;
}
};
auto tracer = [&](expr const & old_e, expr const & new_e) {
if (!is_eqp(new_e, old_e)) {
m_trace[new_e] = old_e;
}
};
replace_fn<decltype(proc), decltype(tracer)> replacer(proc, tracer);
return replacer(e);
}
public: public:
imp(environment const & env, name_set const * defs): imp(environment const & env, name_set const * defs):
m_env(env), m_env(env),
@ -519,13 +594,6 @@ public:
m_owner = nullptr; m_owner = nullptr;
} }
expr mk_metavar() {
unsigned midx = m_metavars.size();
expr r = ::lean::mk_metavar(midx);
m_metavars.push_back(metavar_info());
return r;
}
void clear() { void clear() {
m_trace.clear(); m_trace.clear();
} }
@ -560,12 +628,10 @@ public:
if (has_placeholder(e)) { if (has_placeholder(e)) {
m_constraints.clear(); m_constraints.clear();
m_metavars.clear(); m_metavars.clear();
m_root = mk_metavars(e);
m_owner = &elb; m_owner = &elb;
unsigned num_meta = m_metavars.size();
m_add_constraints = true; m_add_constraints = true;
infer(m_root, context()); m_root = process(e, context()).first;
solve(num_meta); solve();
return instantiate(m_root); return instantiate(m_root);
} else { } else {
return e; return e;
@ -607,7 +673,6 @@ public:
}; };
elaborator::elaborator(environment const & env):m_ptr(new imp(env, nullptr)) {} elaborator::elaborator(environment const & env):m_ptr(new imp(env, nullptr)) {}
elaborator::~elaborator() {} elaborator::~elaborator() {}
expr elaborator::mk_metavar() { return m_ptr->mk_metavar(); }
expr elaborator::operator()(expr const & e) { return (*m_ptr)(e, *this); } expr elaborator::operator()(expr const & e) { return (*m_ptr)(e, *this); }
expr const & elaborator::get_original(expr const & e) const { return m_ptr->get_original(e); } expr const & elaborator::get_original(expr const & e) const { return m_ptr->get_original(e); }
void elaborator::set_interrupt(bool flag) { m_ptr->set_interrupt(flag); } void elaborator::set_interrupt(bool flag) { m_ptr->set_interrupt(flag); }

View file

@ -23,8 +23,6 @@ public:
explicit elaborator(environment const & env); explicit elaborator(environment const & env);
~elaborator(); ~elaborator();
expr mk_metavar();
expr operator()(expr const & e); expr operator()(expr const & e);
/** /**
@ -45,8 +43,30 @@ public:
void display(std::ostream & out) const; void display(std::ostream & out) const;
format pp(formatter & f, options const & o) const; format pp(formatter & f, options const & o) const;
}; };
/** \brief Return true iff \c e is a special constant used to mark application of overloads. */ /**
bool is_overload_marker(expr const & e); \brief Create a choice expression for the given functions.
/** \brief Return the overload marker */ It is used to mark which functions can be used in a particular application.
expr mk_overload_marker(); The elaborator decides which one should be used based on the type of the arguments.
\pre num_fs >= 2
*/
expr mk_choice(unsigned num_fs, expr const * fs);
/**
\brief Return true iff \c e is an expression created using \c mk_choice.
*/
bool is_choice(expr const & e);
/**
\brief Return the number of alternatives in a choice expression.
We have that <tt>get_num_choices(mk_choice(n, fs)) == n</tt>.
\pre is_choice(e)
*/
unsigned get_num_choices(expr const & e);
/**
\brief Return the (i+1)-th alternative of a choice expression.
\pre is_choice(e)
\pre i < get_num_choices(e)
*/
expr const & get_choice(expr const & e, unsigned i);
} }

View file

@ -107,6 +107,7 @@ bool is_subst(expr const & e) {
} }
expr mk_lift_fn(unsigned s, unsigned n) { expr mk_lift_fn(unsigned s, unsigned n) {
lean_assert(n > 0);
return mk_constant(name(name(g_lift_prefix, s), n)); return mk_constant(name(name(g_lift_prefix, s), n));
} }

View file

@ -5,7 +5,7 @@ Error (line: 4, pos: 40) application type mismatch during term elaboration at te
Elaborator state Elaborator state
?M0 := [unassigned] ?M0 := [unassigned]
?M1 := [unassigned] ?M1 := [unassigned]
#0 ≈ lift:0:0 ?M0 #0 ≈ lift:0:2 ?M0
Assumed: myeq Assumed: myeq
myeq (Π (A : Type) (a : A), A) (λ (A : Type) (a : A), a) (λ (B : Type) (b : B), b) myeq (Π (A : Type) (a : A), A) (λ (A : Type) (a : A), a) (λ (B : Type) (b : B), b)
Bool Bool