Refactor elaborator for supporting overloads

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-09-01 10:24:10 -07:00
parent b2924bba99
commit 598daa40bc
5 changed files with 175 additions and 84 deletions

View file

@ -377,9 +377,8 @@ class parser::imp {
/**
\brief Return the function associated with the given operator.
If the operator has been overloaded, it returns an expression
of the form (overload f_k ... (overload f_2 f_1) ...)
where f_i's are different options.
If the operator has been overloaded, it returns a choice expression
of the form <tt>(choice f_1 f_2 ... f_k)</tt> where f_i's are different options.
After we finish parsing, the procedure #elaborate will
resolve/decide which f_i should be used.
*/
@ -389,9 +388,15 @@ class parser::imp {
auto it = fs.begin();
expr r = *it;
++it;
for (; it != fs.end(); ++it)
r = mk_app(mk_overload_marker(), *it, r);
if (it == fs.end()) {
return r;
} else {
buffer<expr> alternatives;
alternatives.push_back(r);
for (; it != fs.end(); ++it)
alternatives.push_back(*it);
return mk_choice(alternatives.size(), alternatives.data());
}
}
/**

View file

@ -17,17 +17,28 @@ Author: Leonardo de Moura
#include "elaborator_exception.h"
namespace lean {
static name g_overload_name(name(name(name(0u), "library"), "overload"));
static expr g_overload = mk_constant(g_overload_name);
static name g_choice_name(name(name(name(0u), "library"), "choice"));
static expr g_choice = mk_constant(g_choice_name);
static format g_assignment_fmt = format(":=");
static format g_unification_fmt = format("\u2248");
bool is_overload_marker(expr const & e) {
return e == g_overload;
expr mk_choice(unsigned num_fs, expr const * fs) {
lean_assert(num_fs >= 2);
return mk_eq(g_choice, mk_app(num_fs, fs));
}
expr mk_overload_marker() {
return g_overload;
bool is_choice(expr const & e) {
return is_eq(e) && eq_lhs(e) == g_choice;
}
unsigned get_num_choices(expr const & e) {
lean_assert(is_choice(e));
return num_args(eq_rhs(e));
}
expr const & get_choice(expr const & e, unsigned i) {
lean_assert(is_choice(e));
return arg(eq_rhs(e), i);
}
class elaborator::imp {
@ -82,13 +93,22 @@ class elaborator::imp {
volatile bool m_interrupted;
expr mk_metavar(context const & ctx) {
unsigned midx = m_metavars.size();
expr r = ::lean::mk_metavar(midx);
m_metavars.push_back(metavar_info());
m_metavars[midx].m_mvar = r;
m_metavars[midx].m_ctx = ctx;
return r;
}
expr metavar_type(expr const & m) {
lean_assert(is_metavar(m));
unsigned midx = metavar_idx(m);
if (m_metavars[midx].m_type) {
return m_metavars[midx].m_type;
} else {
expr t = mk_metavar();
expr t = mk_metavar(m_metavars[midx].m_ctx);
m_metavars[midx].m_type = t;
return t;
}
@ -163,67 +183,139 @@ class elaborator::imp {
}
}
expr infer(expr const & e, context const & ctx) {
typedef std::pair<expr, expr> expr_pair;
/**
\brief Traverse the expression \c e, and compute
1- A new expression that does not contain choice expressions,
coercions have been added when appropriate, and placeholders
have been replaced with metavariables.
2- The type of \c e.
It also populates m_constraints with a set of constraints that
need to be solved to infer the value of the metavariables.
*/
expr_pair process(expr const & e, context const & ctx) {
check_interrupted(m_interrupted);
switch (e.kind()) {
case expr_kind::Constant:
if (is_metavar(e)) {
unsigned midx = metavar_idx(e);
if (!(m_metavars[midx].m_ctx)) {
lean_assert(!(m_metavars[midx].m_mvar));
m_metavars[midx].m_mvar = e;
m_metavars[midx].m_ctx = ctx;
}
return metavar_type(e);
if (is_placeholder(e)) {
expr m = mk_metavar(ctx);
m_trace[m] = e;
return expr_pair(m, metavar_type(m));
} else if (is_metavar(e)) {
return expr_pair(e, metavar_type(e));
} else {
return m_env.get_object(const_name(e)).get_type();
return expr_pair(e, m_env.get_object(const_name(e)).get_type());
}
case expr_kind::Var:
return lookup(ctx, var_idx(e));
return expr_pair(e, lookup(ctx, var_idx(e)));
case expr_kind::Type:
return mk_type(ty_level(e) + 1);
return expr_pair(e, mk_type(ty_level(e) + 1));
case expr_kind::Value:
return to_value(e).get_type();
return expr_pair(e, to_value(e).get_type());
case expr_kind::App: {
buffer<expr> args;
buffer<expr> types;
buffer<expr> f_choices;
buffer<expr> f_choice_types;
unsigned num = num_args(e);
for (unsigned i = 0; i < num; i++) {
types.push_back(infer(arg(e,i), ctx));
}
// TODO: handle overload in args[0]
expr f_t = types[0];
if (!f_t) {
unsigned i = 0;
bool modified = false;
expr const & f = arg(e, 0);
if (is_metavar(f)) {
throw invalid_placeholder_exception(*m_owner, ctx, e);
} else if (is_choice(f)) {
unsigned num_alts = get_num_choices(f);
for (unsigned j = 0; j < num_alts; j++) {
auto p = process(get_choice(f, j), ctx);
f_choices.push_back(p.first);
f_choice_types.push_back(p.second);
}
args.push_back(expr()); // placeholder
types.push_back(expr()); // placeholder
modified = true;
i++;
}
for (; i < num; i++) {
expr const & a_i = arg(e, i);
auto p = process(a_i, ctx);
if (!is_eqp(p.first, a_i))
modified = true;
args.push_back(p.first);
types.push_back(p.second);
}
// TODO: choose an f from f_choices
expr f_t = types[0];
for (unsigned i = 1; i < num; i++) {
f_t = check_pi(f_t, ctx, e, ctx);
if (m_add_constraints)
add_constraint(abst_domain(f_t), types[i], ctx, e, i);
f_t = instantiate_free_var_mmv(abst_body(f_t), 0, arg(e, i));
f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]);
}
if (modified) {
expr new_e = mk_app(args.size(), args.data());
m_trace[new_e] = e;
return expr_pair(new_e, f_t);
} else {
return expr_pair(e, f_t);
}
return f_t;
}
case expr_kind::Eq: {
infer(eq_lhs(e), ctx);
infer(eq_rhs(e), ctx);
return mk_bool_type();
auto lhs_p = process(eq_lhs(e), ctx);
auto rhs_p = process(eq_rhs(e), ctx);
if (is_eqp(lhs_p.first, eq_lhs(e)) && is_eqp(rhs_p.first, eq_rhs(e))) {
return expr_pair(e, mk_bool_type());
} else {
expr new_e = mk_eq(lhs_p.first, rhs_p.first);
m_trace[new_e] = e;
return expr_pair(new_e, mk_bool_type());
}
}
case expr_kind::Pi: {
expr dt = infer(abst_domain(e), ctx);
expr bt = infer(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
return mk_type(max(check_universe(dt, ctx, e, ctx), check_universe(bt, ctx, e, ctx)));
auto d_p = process(abst_domain(e), ctx);
auto b_p = process(abst_body(e), extend(ctx, abst_name(e), d_p.first));
expr t = mk_type(max(check_universe(d_p.second, ctx, e, ctx), check_universe(b_p.second, ctx, e, ctx)));
if (is_eqp(d_p.first, abst_domain(e)) && is_eqp(b_p.first, abst_body(e))) {
return expr_pair(e, t);
} else {
expr new_e = mk_pi(abst_name(e), d_p.first, b_p.first);
m_trace[new_e] = e;
return expr_pair(new_e, t);
}
}
case expr_kind::Lambda: {
expr dt = infer(abst_domain(e), ctx);
expr bt = infer(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
return mk_pi(abst_name(e), abst_domain(e), bt);
auto d_p = process(abst_domain(e), ctx);
auto b_p = process(abst_body(e), extend(ctx, abst_name(e), d_p.first));
expr t = mk_pi(abst_name(e), d_p.first, b_p.second);
if (is_eqp(d_p.first, abst_domain(e)) && is_eqp(b_p.first, abst_body(e))) {
return expr_pair(e, t);
} else {
expr new_e = mk_lambda(abst_name(e), d_p.first, b_p.first);
m_trace[new_e] = e;
return expr_pair(new_e, t);
}
}
case expr_kind::Let: {
expr lt = infer(let_value(e), ctx);
return lower_free_vars_mmv(infer(let_body(e), extend(ctx, let_name(e), lt, let_value(e))), 1, 1);
auto v_p = process(let_value(e), ctx);
auto b_p = process(let_body(e), extend(ctx, let_name(e), v_p.second, v_p.first));
expr t = lower_free_vars_mmv(b_p.second, 1, 1);
if (is_eqp(v_p.first, let_value(e)) && is_eqp(b_p.first, let_body(e))) {
return expr_pair(e, t);
} else {
expr new_e = mk_let(let_name(e), v_p.first, b_p.first);
m_trace[new_e] = e;
return expr_pair(new_e, t);
}
}}
lean_unreachable();
return expr();
return expr_pair(expr(), expr());
}
expr infer(expr const & e, context const & ctx) {
return process(e, ctx).second;
}
bool is_simple_ho_match(expr const & e1, expr const & e2, context const & ctx) {
@ -454,7 +546,8 @@ class elaborator::imp {
return replacer(e);
}
void solve(unsigned num_meta) {
void solve() {
unsigned num_meta = m_metavars.size();
m_add_constraints = false;
while (true) {
solve_core();
@ -493,24 +586,6 @@ class elaborator::imp {
}
}
expr mk_metavars(expr const & e) {
// replace placeholders with fresh metavars
auto proc = [&](expr const & n, unsigned offset) -> expr {
if (is_placeholder(n)) {
return mk_metavar();
} else {
return n;
}
};
auto tracer = [&](expr const & old_e, expr const & new_e) {
if (!is_eqp(new_e, old_e)) {
m_trace[new_e] = old_e;
}
};
replace_fn<decltype(proc), decltype(tracer)> replacer(proc, tracer);
return replacer(e);
}
public:
imp(environment const & env, name_set const * defs):
m_env(env),
@ -519,13 +594,6 @@ public:
m_owner = nullptr;
}
expr mk_metavar() {
unsigned midx = m_metavars.size();
expr r = ::lean::mk_metavar(midx);
m_metavars.push_back(metavar_info());
return r;
}
void clear() {
m_trace.clear();
}
@ -560,12 +628,10 @@ public:
if (has_placeholder(e)) {
m_constraints.clear();
m_metavars.clear();
m_root = mk_metavars(e);
m_owner = &elb;
unsigned num_meta = m_metavars.size();
m_add_constraints = true;
infer(m_root, context());
solve(num_meta);
m_root = process(e, context()).first;
solve();
return instantiate(m_root);
} else {
return e;
@ -607,7 +673,6 @@ public:
};
elaborator::elaborator(environment const & env):m_ptr(new imp(env, nullptr)) {}
elaborator::~elaborator() {}
expr elaborator::mk_metavar() { return m_ptr->mk_metavar(); }
expr elaborator::operator()(expr const & e) { return (*m_ptr)(e, *this); }
expr const & elaborator::get_original(expr const & e) const { return m_ptr->get_original(e); }
void elaborator::set_interrupt(bool flag) { m_ptr->set_interrupt(flag); }

View file

@ -23,8 +23,6 @@ public:
explicit elaborator(environment const & env);
~elaborator();
expr mk_metavar();
expr operator()(expr const & e);
/**
@ -45,8 +43,30 @@ public:
void display(std::ostream & out) const;
format pp(formatter & f, options const & o) const;
};
/** \brief Return true iff \c e is a special constant used to mark application of overloads. */
bool is_overload_marker(expr const & e);
/** \brief Return the overload marker */
expr mk_overload_marker();
/**
\brief Create a choice expression for the given functions.
It is used to mark which functions can be used in a particular application.
The elaborator decides which one should be used based on the type of the arguments.
\pre num_fs >= 2
*/
expr mk_choice(unsigned num_fs, expr const * fs);
/**
\brief Return true iff \c e is an expression created using \c mk_choice.
*/
bool is_choice(expr const & e);
/**
\brief Return the number of alternatives in a choice expression.
We have that <tt>get_num_choices(mk_choice(n, fs)) == n</tt>.
\pre is_choice(e)
*/
unsigned get_num_choices(expr const & e);
/**
\brief Return the (i+1)-th alternative of a choice expression.
\pre is_choice(e)
\pre i < get_num_choices(e)
*/
expr const & get_choice(expr const & e, unsigned i);
}

View file

@ -107,6 +107,7 @@ bool is_subst(expr const & e) {
}
expr mk_lift_fn(unsigned s, unsigned n) {
lean_assert(n > 0);
return mk_constant(name(name(g_lift_prefix, s), n));
}

View file

@ -5,7 +5,7 @@ Error (line: 4, pos: 40) application type mismatch during term elaboration at te
Elaborator state
?M0 := [unassigned]
?M1 := [unassigned]
#0 ≈ lift:0:0 ?M0
#0 ≈ lift:0:2 ?M0
Assumed: myeq
myeq (Π (A : Type) (a : A), A) (λ (A : Type) (a : A), a) (λ (B : Type) (b : B), b)
Bool