feat(kernel): simplify choice_fn, and make its interface closer to the unifier_plugin interface

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-07-04 12:47:33 -07:00
parent b94ce412ae
commit d7cb1952ae
5 changed files with 65 additions and 122 deletions

View file

@ -30,7 +30,7 @@ Author: Leonardo de Moura
namespace lean { namespace lean {
class elaborator { class elaborator {
typedef list<expr> context; typedef list<expr> context;
typedef std::vector<constraint> constraints; typedef std::vector<constraint> constraint_vect;
typedef name_map<expr> tactic_hints; typedef name_map<expr> tactic_hints;
typedef name_map<expr> mvar2meta; typedef name_map<expr> mvar2meta;
@ -43,7 +43,7 @@ class elaborator {
context m_ctx; context m_ctx;
pos_info_provider * m_pos_provider; pos_info_provider * m_pos_provider;
justification m_accumulated; // accumulate justification of eagerly used substitutions justification m_accumulated; // accumulate justification of eagerly used substitutions
constraints m_constraints; constraint_vect m_constraints;
tactic_hints m_tactic_hints; tactic_hints m_tactic_hints;
mvar2meta m_mvar2meta; mvar2meta m_mvar2meta;
name_set m_displayed_errors; // set of metavariables that we already reported unsolved/unassigned name_set m_displayed_errors; // set of metavariables that we already reported unsolved/unassigned
@ -84,15 +84,16 @@ class elaborator {
struct choice_elaborator { struct choice_elaborator {
elaborator & m_elab; elaborator & m_elab;
expr m_mvar;
expr m_choice; expr m_choice;
context m_ctx; context m_ctx;
substitution m_subst; substitution m_subst;
unsigned m_idx; unsigned m_idx;
choice_elaborator(elaborator & elab, expr const & c, context const & ctx, substitution const & s): choice_elaborator(elaborator & elab, expr const & mvar, expr const & c, context const & ctx, substitution const & s):
m_elab(elab), m_choice(c), m_ctx(ctx), m_subst(s), m_idx(0) { m_elab(elab), m_mvar(mvar), m_choice(c), m_ctx(ctx), m_subst(s), m_idx(0) {
} }
optional<a_choice> next() { optional<constraints> next() {
while (m_idx < get_num_choices(m_choice)) { while (m_idx < get_num_choices(m_choice)) {
expr const & c = get_choice(m_choice, m_idx); expr const & c = get_choice(m_choice, m_idx);
m_idx++; m_idx++;
@ -102,20 +103,21 @@ class elaborator {
justification j = m_elab.m_accumulated; justification j = m_elab.m_accumulated;
m_elab.consume_tc_cnstrs(); m_elab.consume_tc_cnstrs();
list<constraint> cs = to_list(m_elab.m_constraints.begin(), m_elab.m_constraints.end()); list<constraint> cs = to_list(m_elab.m_constraints.begin(), m_elab.m_constraints.end());
return optional<a_choice>(r, j, cs); cs = cons(mk_eq_cnstr(m_mvar, r, j), cs);
return optional<constraints>(cs);
} catch (exception &) {} } catch (exception &) {}
} }
return optional<a_choice>(); return optional<constraints>();
} }
}; };
lazy_list<a_choice> choose(std::shared_ptr<choice_elaborator> c) { lazy_list<constraints> choose(std::shared_ptr<choice_elaborator> c) {
return mk_lazy_list<a_choice>([=]() { return mk_lazy_list<constraints>([=]() {
auto s = c->next(); auto s = c->next();
if (s) if (s)
return some(mk_pair(*s, choose(c))); return some(mk_pair(*s, choose(c)));
else else
return lazy_list<a_choice>::maybe_pair(); return lazy_list<constraints>::maybe_pair();
}); });
} }
@ -326,11 +328,11 @@ public:
// Possible optimization: try to lookahead and discard some of the alternatives. // Possible optimization: try to lookahead and discard some of the alternatives.
expr m = mk_meta(t, e.get_tag()); expr m = mk_meta(t, e.get_tag());
context ctx = m_ctx; context ctx = m_ctx;
auto choice_fn = [=](expr const & /* t */, substitution const & s, name_generator const & /* ngen */) { auto fn = [=](expr const & mvar, expr const & /* type */, substitution const & s, name_generator const & /* ngen */) {
return choose(std::make_shared<choice_elaborator>(*this, e, ctx, s)); return choose(std::make_shared<choice_elaborator>(*this, mvar, e, ctx, s));
}; };
justification j = mk_justification("none of the overloads is applicable", some_expr(e)); justification j = mk_justification("none of the overloads is applicable", some_expr(e));
add_cnstr(mk_choice_cnstr(m, choice_fn, false, j)); add_cnstr(mk_choice_cnstr(m, fn, false, j));
return m; return m;
} }
@ -404,10 +406,9 @@ public:
expr mk_delayed_coercion(expr const & e, expr const & d_type, expr const & a_type) { expr mk_delayed_coercion(expr const & e, expr const & d_type, expr const & a_type) {
expr a = app_arg(e); expr a = app_arg(e);
expr m = mk_meta(some_expr(d_type), a.get_tag()); expr m = mk_meta(some_expr(d_type), a.get_tag());
auto choice_fn = [=](expr const & new_d_type, substitution const & /* s */, name_generator const & /* ngen */) { auto choice_fn = [=](expr const & mvar, expr const & new_d_type, substitution const & /* s */, name_generator const & /* ngen */) {
expr r = apply_coercion(a, a_type, new_d_type); expr r = apply_coercion(a, a_type, new_d_type);
a_choice c(r, justification(), list<constraint>()); return lazy_list<constraints>(constraints(mk_eq_cnstr(mvar, r, justification())));
return lazy_list<a_choice>(c);
}; };
justification j = mk_app_justification(m_env, e, d_type, a_type); justification j = mk_app_justification(m_env, e, d_type, a_type);
add_cnstr(mk_choice_cnstr(m, choice_fn, false, j)); add_cnstr(mk_choice_cnstr(m, choice_fn, false, j));

View file

@ -34,22 +34,18 @@ namespace lean {
enum class constraint_kind { Eq, LevelEq, Choice }; enum class constraint_kind { Eq, LevelEq, Choice };
class constraint; class constraint;
typedef list<constraint> constraints; typedef list<constraint> constraints;
typedef std::tuple<expr, justification, constraints> a_choice; // a choice produced by the choice_fn
/** /**
\brief A choice_fn is used to enumerate the possible solutions for a metavariable. \brief A choice_fn is used to enumerate the possible solutions for a metavariable.
The input arguments are: The input arguments are:
- an inferred type - metavariable that should be inferred
- the metavariable type
- substitution map (metavar -> value) - substitution map (metavar -> value)
- name generator - name generator
The result is a lazy_list of constraints
The result is a lazy_list of choices, i.e., tuples containing:
- an expression representing one of the possible solutions
- a justification for it (this is used to accumulate the justification for the substitutions used).
- a list of new constraints (that is, the solution is only valid if the additional constraints can be solved)
One application of choice constraints is overloaded notation. One application of choice constraints is overloaded notation.
*/ */
typedef std::function<lazy_list<a_choice>(expr const &, substitution const &, name_generator const &)> choice_fn; typedef std::function<lazy_list<constraints>(expr const &, expr const &, substitution const &, name_generator const &)> choice_fn;
struct constraint_cell; struct constraint_cell;
class constraint { class constraint {
@ -76,6 +72,9 @@ public:
constraint_cell * raw() const { return m_ptr; } constraint_cell * raw() const { return m_ptr; }
}; };
inline bool operator==(constraint const & c1, constraint const & c2) { return c1.raw() == c2.raw(); }
inline bool operator!=(constraint const & c1, constraint const & c2) { return !(c1 == c2); }
constraint mk_eq_cnstr(expr const & lhs, expr const & rhs, justification const & j); constraint mk_eq_cnstr(expr const & lhs, expr const & rhs, justification const & j);
constraint mk_level_eq_cnstr(level const & lhs, level const & rhs, justification const & j); constraint mk_level_eq_cnstr(level const & lhs, level const & rhs, justification const & j);
constraint mk_choice_cnstr(expr const & m, choice_fn const & fn, bool delayed, justification const & j); constraint mk_choice_cnstr(expr const & m, choice_fn const & fn, bool delayed, justification const & j);

View file

@ -1431,6 +1431,7 @@ static void open_justification(lua_State * L) {
// Constraint // Constraint
DECL_UDATA(constraint) DECL_UDATA(constraint)
DEFINE_LUA_LIST(constraint, push_constraint, to_constraint)
int push_optional_constraint(lua_State * L, optional<constraint> const & c) { return c ? push_constraint(L, *c) : push_nil(L); } int push_optional_constraint(lua_State * L, optional<constraint> const & c) { return c ? push_constraint(L, *c) : push_nil(L); }
#define CNSTR_PRED(P) static int constraint_ ## P(lua_State * L) { check_num_args(L, 1); return push_boolean(L, P(to_constraint(L, 1))); } #define CNSTR_PRED(P) static int constraint_ ## P(lua_State * L) { check_num_args(L, 1); return push_boolean(L, P(to_constraint(L, 1))); }
CNSTR_PRED(is_eq_cnstr) CNSTR_PRED(is_eq_cnstr)
@ -1475,14 +1476,15 @@ static int mk_level_eq_cnstr(lua_State * L) {
static choice_fn to_choice_fn(lua_State * L, int idx) { static choice_fn to_choice_fn(lua_State * L, int idx) {
luaL_checktype(L, idx, LUA_TFUNCTION); // user-fun luaL_checktype(L, idx, LUA_TFUNCTION); // user-fun
luaref f(L, idx); luaref f(L, idx);
return choice_fn([=](expr const & e, substitution const & s, name_generator const & ngen) { return choice_fn([=](expr const & mvar, expr const & mvar_type, substitution const & s, name_generator const & ngen) {
lua_State * L = f.get_state(); lua_State * L = f.get_state();
f.push(); f.push();
push_expr(L, e); push_expr(L, mvar);
push_expr(L, mvar_type);
push_substitution(L, s); push_substitution(L, s);
push_name_generator(L, ngen); push_name_generator(L, ngen);
pcall(L, 3, 1, 0); pcall(L, 4, 1, 0);
buffer<a_choice> r; buffer<constraints> r;
if (lua_isnil(L, -1)) { if (lua_isnil(L, -1)) {
// do nothing // do nothing
} else if (lua_istable(L, -1)) { } else if (lua_istable(L, -1)) {
@ -1490,35 +1492,12 @@ static choice_fn to_choice_fn(lua_State * L, int idx) {
// each entry is an alternative // each entry is an alternative
for (int i = 1; i <= num; i++) { for (int i = 1; i <= num; i++) {
lua_rawgeti(L, -1, i); lua_rawgeti(L, -1, i);
if (is_expr(L, -1)) { if (is_constraint(L, -1))
r.push_back(a_choice(to_expr(L, -1), justification(), constraints())); r.push_back(constraints(to_constraint(L, -1)));
} else if (lua_istable(L, -1) && objlen(L, -1) == 3) { else if (is_expr(L, -1))
lua_rawgeti(L, -1, 1); r.push_back(constraints(mk_eq_cnstr(mvar, to_expr(L, -1), justification())));
expr c = to_expr(L, -1); else
lua_pop(L, 1); r.push_back(to_list_constraint_ext(L, -1));
lua_rawgeti(L, -1, 2);
justification j = to_justification(L, -1);
lua_pop(L, 1);
lua_rawgeti(L, -1, 3);
buffer<constraint> cs;
if (lua_isnil(L, -1)) {
// do nothing
} else if (lua_istable(L, -1)) {
int num_cs = objlen(L, -1);
for (int i = 1; i <= num_cs; i++) {
lua_rawgeti(L, -1, i);
cs.push_back(to_constraint(L, -1));
lua_pop(L, 1);
}
} else {
throw exception("invalid choice function, result must be an array of triples, "
"where the third element of each triple is an array of constraints");
}
lua_pop(L, 1);
r.push_back(a_choice(c, j, to_list(cs.begin(), cs.end())));
} else {
throw exception("invalid choice function, result must be an array of triples");
}
lua_pop(L, 1); lua_pop(L, 1);
} }
} else { } else {
@ -1960,6 +1939,7 @@ void open_kernel_module(lua_State * L) {
open_io_state(L); open_io_state(L);
open_justification(L); open_justification(L);
open_constraint(L); open_constraint(L);
open_list_constraint(L);
open_substitution(L); open_substitution(L);
open_type_checker(L); open_type_checker(L);
open_inductive(L); open_inductive(L);

View file

@ -288,19 +288,10 @@ struct unifier_fn {
}; };
typedef std::vector<std::unique_ptr<case_split>> case_split_stack; typedef std::vector<std::unique_ptr<case_split>> case_split_stack;
struct plugin_case_split : public case_split { struct lazy_constraints_case_split : public case_split {
lazy_list<constraints> m_tail; lazy_list<constraints> m_tail;
plugin_case_split(unifier_fn & u, lazy_list<constraints> const & tail):case_split(u), m_tail(tail) {} lazy_constraints_case_split(unifier_fn & u, lazy_list<constraints> const & tail):case_split(u), m_tail(tail) {}
virtual bool next(unifier_fn & u) { return u.next_plugin_case_split(*this); } virtual bool next(unifier_fn & u) { return u.next_lazy_constraints_case_split(*this); }
};
struct choice_case_split : public case_split {
expr m_expr;
justification m_jst;
lazy_list<a_choice> m_tail;
choice_case_split(unifier_fn & u, expr const & expr, justification const & j, lazy_list<a_choice> const & tail):
case_split(u), m_expr(expr), m_jst(j), m_tail(tail) {}
virtual bool next(unifier_fn & u) { return u.next_choice_case_split(*this); }
}; };
struct ho_case_split : public case_split { struct ho_case_split : public case_split {
@ -775,21 +766,13 @@ struct unifier_fn {
return !in_conflict(); return !in_conflict();
} }
bool process_choice_result(expr const & m, a_choice const & r, justification j) { bool next_lazy_constraints_case_split(lazy_constraints_case_split & cs) {
j = mk_composite1(j, std::get<1>(r));
return
process_constraint(mk_eq_cnstr(m, std::get<0>(r), j)) &&
process_constraints(std::get<2>(r), j);
}
bool next_choice_case_split(choice_case_split & cs) {
auto r = cs.m_tail.pull(); auto r = cs.m_tail.pull();
if (r) { if (r) {
cs.restore_state(*this); cs.restore_state(*this);
lean_assert(!in_conflict()); lean_assert(!in_conflict());
cs.m_tail = r->second; cs.m_tail = r->second;
justification a = mk_assumption_justification(cs.m_assumption_idx); return process_constraints(r->first, mk_assumption_justification(cs.m_assumption_idx));
return process_choice_result(cs.m_expr, r->first, mk_composite1(cs.m_jst, a));
} else { } else {
// update conflict // update conflict
update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications)); update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications));
@ -797,22 +780,16 @@ struct unifier_fn {
} }
} }
bool process_choice_constraint(constraint const & c) { bool process_lazy_constraints(lazy_list<constraints> const & l, justification const & j) {
lean_assert(is_choice_cnstr(c)); auto r = l.pull();
expr const & m = cnstr_expr(c);
choice_fn const & fn = cnstr_choice_fn(c);
auto m_type_jst = m_subst.instantiate_metavars(m_tc.infer(m), nullptr, nullptr);
auto rlist = fn(m_type_jst.first, m_subst, m_ngen.mk_child());
auto r = rlist.pull();
justification j = mk_composite1(c.get_justification(), m_type_jst.second);
if (r) { if (r) {
if (r->second.is_nil()) { if (r->second.is_nil()) {
// there is only one alternative // there is only one alternative
return process_choice_result(m, r->first, j); return process_constraints(r->first, j);
} else { } else {
justification a = mk_assumption_justification(m_next_assumption_idx); justification a = mk_assumption_justification(m_next_assumption_idx);
add_case_split(std::unique_ptr<case_split>(new choice_case_split(*this, m, m_type_jst.second, r->second))); add_case_split(std::unique_ptr<case_split>(new lazy_constraints_case_split(*this, r->second)));
return process_choice_result(m, r->first, mk_composite1(j, a)); return process_constraints(r->first, mk_composite1(j, a));
} }
} else { } else {
set_conflict(j); set_conflict(j);
@ -820,6 +797,21 @@ struct unifier_fn {
} }
} }
bool process_plugin_constraint(constraint const & c) {
lean_assert(!is_choice_cnstr(c));
lazy_list<constraints> alts = m_plugin(c, m_ngen.mk_child());
return process_lazy_constraints(alts, c.get_justification());
}
bool process_choice_constraint(constraint const & c) {
lean_assert(is_choice_cnstr(c));
expr const & m = cnstr_expr(c);
choice_fn const & fn = cnstr_choice_fn(c);
auto m_type_jst = m_subst.instantiate_metavars(m_tc.infer(m), nullptr, nullptr);
lazy_list<constraints> alts = fn(m, m_type_jst.first, m_subst, m_ngen.mk_child());
return process_lazy_constraints(alts, mk_composite1(c.get_justification(), m_type_jst.second));
}
/** \brief Return true iff \c e is of the form (elim ... (?m ...)) */ /** \brief Return true iff \c e is of the form (elim ... (?m ...)) */
bool is_elim_meta_app(expr const & e) { bool is_elim_meta_app(expr const & e) {
if (!is_app(e)) if (!is_app(e))
@ -921,35 +913,6 @@ struct unifier_fn {
return process_elim_meta_core(rhs, lhs, j); return process_elim_meta_core(rhs, lhs, j);
} }
bool next_plugin_case_split(plugin_case_split & cs) {
auto r = cs.m_tail.pull();
if (r) {
cs.restore_state(*this);
lean_assert(!in_conflict());
cs.m_tail = r->second;
return process_constraints(r->first, mk_assumption_justification(cs.m_assumption_idx));
} else {
// update conflict
update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications));
return false;
}
}
bool process_plugin_constraint(constraint const & c) {
lean_assert(!is_choice_cnstr(c));
lazy_list<constraints> alts = m_plugin(c, m_ngen.mk_child());
auto r = alts.pull();
if (!r) {
set_conflict(c.get_justification());
return false;
} else {
// create a backtracking point
justification a = mk_assumption_justification(m_next_assumption_idx);
add_case_split(std::unique_ptr<case_split>(new plugin_case_split(*this, r->second)));
return process_constraints(r->first, a);
}
}
bool next_ho_case_split(ho_case_split & cs) { bool next_ho_case_split(ho_case_split & cs) {
if (!is_nil(cs.m_tail)) { if (!is_nil(cs.m_tail)) {
cs.restore_state(*this); cs.restore_state(*this);

View file

@ -29,9 +29,9 @@ function display_solutions(m, ss)
end end
cs = { mk_eq_cnstr(m1, f(m2, f(m3, m4))), cs = { mk_eq_cnstr(m1, f(m2, f(m3, m4))),
mk_choice_cnstr(m2, function(e, s, ngen) return {{a, justification(), {}}, {f(a, a), justification(), {}}} end), mk_choice_cnstr(m2, function(m, e, s, ngen) return {{mk_eq_cnstr(m, a)}, {mk_eq_cnstr(m, f(a, a))}} end),
mk_choice_cnstr(m3, function(e, s, ngen) return {{b, justification(), {}}, {f(b, b), justification(), {}}} end), mk_choice_cnstr(m3, function(m, e, s, ngen) return {mk_eq_cnstr(m, b), mk_eq_cnstr(m, f(b, b))} end),
mk_choice_cnstr(m4, function(e, s, ngen) return {a, b} end) mk_choice_cnstr(m4, function(m, e, s, ngen) return {a, b} end)
} }
display_solutions(m1, unify(env, cs, name_generator(), o)) display_solutions(m1, unify(env, cs, name_generator(), o))