feat(kernel): simplify choice_fn, and make its interface closer to the unifier_plugin interface

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-07-04 12:47:33 -07:00
parent b94ce412ae
commit d7cb1952ae
5 changed files with 65 additions and 122 deletions

View file

@ -30,7 +30,7 @@ Author: Leonardo de Moura
namespace lean {
class elaborator {
typedef list<expr> context;
typedef std::vector<constraint> constraints;
typedef std::vector<constraint> constraint_vect;
typedef name_map<expr> tactic_hints;
typedef name_map<expr> mvar2meta;
@ -43,7 +43,7 @@ class elaborator {
context m_ctx;
pos_info_provider * m_pos_provider;
justification m_accumulated; // accumulate justification of eagerly used substitutions
constraints m_constraints;
constraint_vect m_constraints;
tactic_hints m_tactic_hints;
mvar2meta m_mvar2meta;
name_set m_displayed_errors; // set of metavariables that we already reported unsolved/unassigned
@ -84,15 +84,16 @@ class elaborator {
struct choice_elaborator {
elaborator & m_elab;
expr m_mvar;
expr m_choice;
context m_ctx;
substitution m_subst;
unsigned m_idx;
choice_elaborator(elaborator & elab, expr const & c, context const & ctx, substitution const & s):
m_elab(elab), m_choice(c), m_ctx(ctx), m_subst(s), m_idx(0) {
choice_elaborator(elaborator & elab, expr const & mvar, expr const & c, context const & ctx, substitution const & s):
m_elab(elab), m_mvar(mvar), m_choice(c), m_ctx(ctx), m_subst(s), m_idx(0) {
}
optional<a_choice> next() {
optional<constraints> next() {
while (m_idx < get_num_choices(m_choice)) {
expr const & c = get_choice(m_choice, m_idx);
m_idx++;
@ -102,20 +103,21 @@ class elaborator {
justification j = m_elab.m_accumulated;
m_elab.consume_tc_cnstrs();
list<constraint> cs = to_list(m_elab.m_constraints.begin(), m_elab.m_constraints.end());
return optional<a_choice>(r, j, cs);
cs = cons(mk_eq_cnstr(m_mvar, r, j), cs);
return optional<constraints>(cs);
} catch (exception &) {}
}
return optional<a_choice>();
return optional<constraints>();
}
};
lazy_list<a_choice> choose(std::shared_ptr<choice_elaborator> c) {
return mk_lazy_list<a_choice>([=]() {
lazy_list<constraints> choose(std::shared_ptr<choice_elaborator> c) {
return mk_lazy_list<constraints>([=]() {
auto s = c->next();
if (s)
return some(mk_pair(*s, choose(c)));
else
return lazy_list<a_choice>::maybe_pair();
return lazy_list<constraints>::maybe_pair();
});
}
@ -326,11 +328,11 @@ public:
// Possible optimization: try to lookahead and discard some of the alternatives.
expr m = mk_meta(t, e.get_tag());
context ctx = m_ctx;
auto choice_fn = [=](expr const & /* t */, substitution const & s, name_generator const & /* ngen */) {
return choose(std::make_shared<choice_elaborator>(*this, e, ctx, s));
auto fn = [=](expr const & mvar, expr const & /* type */, substitution const & s, name_generator const & /* ngen */) {
return choose(std::make_shared<choice_elaborator>(*this, mvar, e, ctx, s));
};
justification j = mk_justification("none of the overloads is applicable", some_expr(e));
add_cnstr(mk_choice_cnstr(m, choice_fn, false, j));
add_cnstr(mk_choice_cnstr(m, fn, false, j));
return m;
}
@ -404,10 +406,9 @@ public:
expr mk_delayed_coercion(expr const & e, expr const & d_type, expr const & a_type) {
expr a = app_arg(e);
expr m = mk_meta(some_expr(d_type), a.get_tag());
auto choice_fn = [=](expr const & new_d_type, substitution const & /* s */, name_generator const & /* ngen */) {
auto choice_fn = [=](expr const & mvar, expr const & new_d_type, substitution const & /* s */, name_generator const & /* ngen */) {
expr r = apply_coercion(a, a_type, new_d_type);
a_choice c(r, justification(), list<constraint>());
return lazy_list<a_choice>(c);
return lazy_list<constraints>(constraints(mk_eq_cnstr(mvar, r, justification())));
};
justification j = mk_app_justification(m_env, e, d_type, a_type);
add_cnstr(mk_choice_cnstr(m, choice_fn, false, j));

View file

@ -34,22 +34,18 @@ namespace lean {
enum class constraint_kind { Eq, LevelEq, Choice };
class constraint;
typedef list<constraint> constraints;
typedef std::tuple<expr, justification, constraints> a_choice; // a choice produced by the choice_fn
/**
\brief A choice_fn is used to enumerate the possible solutions for a metavariable.
The input arguments are:
- an inferred type
- metavariable that should be inferred
- the metavariable type
- substitution map (metavar -> value)
- name generator
The result is a lazy_list of choices, i.e., tuples containing:
- an expression representing one of the possible solutions
- a justification for it (this is used to accumulate the justification for the substitutions used).
- a list of new constraints (that is, the solution is only valid if the additional constraints can be solved)
The result is a lazy_list of constraints
One application of choice constraints is overloaded notation.
*/
typedef std::function<lazy_list<a_choice>(expr const &, substitution const &, name_generator const &)> choice_fn;
typedef std::function<lazy_list<constraints>(expr const &, expr const &, substitution const &, name_generator const &)> choice_fn;
struct constraint_cell;
class constraint {
@ -76,6 +72,9 @@ public:
constraint_cell * raw() const { return m_ptr; }
};
inline bool operator==(constraint const & c1, constraint const & c2) { return c1.raw() == c2.raw(); }
inline bool operator!=(constraint const & c1, constraint const & c2) { return !(c1 == c2); }
constraint mk_eq_cnstr(expr const & lhs, expr const & rhs, justification const & j);
constraint mk_level_eq_cnstr(level const & lhs, level const & rhs, justification const & j);
constraint mk_choice_cnstr(expr const & m, choice_fn const & fn, bool delayed, justification const & j);

View file

@ -1431,6 +1431,7 @@ static void open_justification(lua_State * L) {
// Constraint
DECL_UDATA(constraint)
DEFINE_LUA_LIST(constraint, push_constraint, to_constraint)
int push_optional_constraint(lua_State * L, optional<constraint> const & c) { return c ? push_constraint(L, *c) : push_nil(L); }
#define CNSTR_PRED(P) static int constraint_ ## P(lua_State * L) { check_num_args(L, 1); return push_boolean(L, P(to_constraint(L, 1))); }
CNSTR_PRED(is_eq_cnstr)
@ -1475,14 +1476,15 @@ static int mk_level_eq_cnstr(lua_State * L) {
static choice_fn to_choice_fn(lua_State * L, int idx) {
luaL_checktype(L, idx, LUA_TFUNCTION); // user-fun
luaref f(L, idx);
return choice_fn([=](expr const & e, substitution const & s, name_generator const & ngen) {
return choice_fn([=](expr const & mvar, expr const & mvar_type, substitution const & s, name_generator const & ngen) {
lua_State * L = f.get_state();
f.push();
push_expr(L, e);
push_expr(L, mvar);
push_expr(L, mvar_type);
push_substitution(L, s);
push_name_generator(L, ngen);
pcall(L, 3, 1, 0);
buffer<a_choice> r;
pcall(L, 4, 1, 0);
buffer<constraints> r;
if (lua_isnil(L, -1)) {
// do nothing
} else if (lua_istable(L, -1)) {
@ -1490,35 +1492,12 @@ static choice_fn to_choice_fn(lua_State * L, int idx) {
// each entry is an alternative
for (int i = 1; i <= num; i++) {
lua_rawgeti(L, -1, i);
if (is_expr(L, -1)) {
r.push_back(a_choice(to_expr(L, -1), justification(), constraints()));
} else if (lua_istable(L, -1) && objlen(L, -1) == 3) {
lua_rawgeti(L, -1, 1);
expr c = to_expr(L, -1);
lua_pop(L, 1);
lua_rawgeti(L, -1, 2);
justification j = to_justification(L, -1);
lua_pop(L, 1);
lua_rawgeti(L, -1, 3);
buffer<constraint> cs;
if (lua_isnil(L, -1)) {
// do nothing
} else if (lua_istable(L, -1)) {
int num_cs = objlen(L, -1);
for (int i = 1; i <= num_cs; i++) {
lua_rawgeti(L, -1, i);
cs.push_back(to_constraint(L, -1));
lua_pop(L, 1);
}
} else {
throw exception("invalid choice function, result must be an array of triples, "
"where the third element of each triple is an array of constraints");
}
lua_pop(L, 1);
r.push_back(a_choice(c, j, to_list(cs.begin(), cs.end())));
} else {
throw exception("invalid choice function, result must be an array of triples");
}
if (is_constraint(L, -1))
r.push_back(constraints(to_constraint(L, -1)));
else if (is_expr(L, -1))
r.push_back(constraints(mk_eq_cnstr(mvar, to_expr(L, -1), justification())));
else
r.push_back(to_list_constraint_ext(L, -1));
lua_pop(L, 1);
}
} else {
@ -1960,6 +1939,7 @@ void open_kernel_module(lua_State * L) {
open_io_state(L);
open_justification(L);
open_constraint(L);
open_list_constraint(L);
open_substitution(L);
open_type_checker(L);
open_inductive(L);

View file

@ -288,19 +288,10 @@ struct unifier_fn {
};
typedef std::vector<std::unique_ptr<case_split>> case_split_stack;
struct plugin_case_split : public case_split {
struct lazy_constraints_case_split : public case_split {
lazy_list<constraints> m_tail;
plugin_case_split(unifier_fn & u, lazy_list<constraints> const & tail):case_split(u), m_tail(tail) {}
virtual bool next(unifier_fn & u) { return u.next_plugin_case_split(*this); }
};
struct choice_case_split : public case_split {
expr m_expr;
justification m_jst;
lazy_list<a_choice> m_tail;
choice_case_split(unifier_fn & u, expr const & expr, justification const & j, lazy_list<a_choice> const & tail):
case_split(u), m_expr(expr), m_jst(j), m_tail(tail) {}
virtual bool next(unifier_fn & u) { return u.next_choice_case_split(*this); }
lazy_constraints_case_split(unifier_fn & u, lazy_list<constraints> const & tail):case_split(u), m_tail(tail) {}
virtual bool next(unifier_fn & u) { return u.next_lazy_constraints_case_split(*this); }
};
struct ho_case_split : public case_split {
@ -775,21 +766,13 @@ struct unifier_fn {
return !in_conflict();
}
bool process_choice_result(expr const & m, a_choice const & r, justification j) {
j = mk_composite1(j, std::get<1>(r));
return
process_constraint(mk_eq_cnstr(m, std::get<0>(r), j)) &&
process_constraints(std::get<2>(r), j);
}
bool next_choice_case_split(choice_case_split & cs) {
bool next_lazy_constraints_case_split(lazy_constraints_case_split & cs) {
auto r = cs.m_tail.pull();
if (r) {
cs.restore_state(*this);
lean_assert(!in_conflict());
cs.m_tail = r->second;
justification a = mk_assumption_justification(cs.m_assumption_idx);
return process_choice_result(cs.m_expr, r->first, mk_composite1(cs.m_jst, a));
return process_constraints(r->first, mk_assumption_justification(cs.m_assumption_idx));
} else {
// update conflict
update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications));
@ -797,22 +780,16 @@ struct unifier_fn {
}
}
bool process_choice_constraint(constraint const & c) {
lean_assert(is_choice_cnstr(c));
expr const & m = cnstr_expr(c);
choice_fn const & fn = cnstr_choice_fn(c);
auto m_type_jst = m_subst.instantiate_metavars(m_tc.infer(m), nullptr, nullptr);
auto rlist = fn(m_type_jst.first, m_subst, m_ngen.mk_child());
auto r = rlist.pull();
justification j = mk_composite1(c.get_justification(), m_type_jst.second);
bool process_lazy_constraints(lazy_list<constraints> const & l, justification const & j) {
auto r = l.pull();
if (r) {
if (r->second.is_nil()) {
// there is only one alternative
return process_choice_result(m, r->first, j);
return process_constraints(r->first, j);
} else {
justification a = mk_assumption_justification(m_next_assumption_idx);
add_case_split(std::unique_ptr<case_split>(new choice_case_split(*this, m, m_type_jst.second, r->second)));
return process_choice_result(m, r->first, mk_composite1(j, a));
add_case_split(std::unique_ptr<case_split>(new lazy_constraints_case_split(*this, r->second)));
return process_constraints(r->first, mk_composite1(j, a));
}
} else {
set_conflict(j);
@ -820,6 +797,21 @@ struct unifier_fn {
}
}
bool process_plugin_constraint(constraint const & c) {
lean_assert(!is_choice_cnstr(c));
lazy_list<constraints> alts = m_plugin(c, m_ngen.mk_child());
return process_lazy_constraints(alts, c.get_justification());
}
bool process_choice_constraint(constraint const & c) {
lean_assert(is_choice_cnstr(c));
expr const & m = cnstr_expr(c);
choice_fn const & fn = cnstr_choice_fn(c);
auto m_type_jst = m_subst.instantiate_metavars(m_tc.infer(m), nullptr, nullptr);
lazy_list<constraints> alts = fn(m, m_type_jst.first, m_subst, m_ngen.mk_child());
return process_lazy_constraints(alts, mk_composite1(c.get_justification(), m_type_jst.second));
}
/** \brief Return true iff \c e is of the form (elim ... (?m ...)) */
bool is_elim_meta_app(expr const & e) {
if (!is_app(e))
@ -921,35 +913,6 @@ struct unifier_fn {
return process_elim_meta_core(rhs, lhs, j);
}
bool next_plugin_case_split(plugin_case_split & cs) {
auto r = cs.m_tail.pull();
if (r) {
cs.restore_state(*this);
lean_assert(!in_conflict());
cs.m_tail = r->second;
return process_constraints(r->first, mk_assumption_justification(cs.m_assumption_idx));
} else {
// update conflict
update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications));
return false;
}
}
bool process_plugin_constraint(constraint const & c) {
lean_assert(!is_choice_cnstr(c));
lazy_list<constraints> alts = m_plugin(c, m_ngen.mk_child());
auto r = alts.pull();
if (!r) {
set_conflict(c.get_justification());
return false;
} else {
// create a backtracking point
justification a = mk_assumption_justification(m_next_assumption_idx);
add_case_split(std::unique_ptr<case_split>(new plugin_case_split(*this, r->second)));
return process_constraints(r->first, a);
}
}
bool next_ho_case_split(ho_case_split & cs) {
if (!is_nil(cs.m_tail)) {
cs.restore_state(*this);

View file

@ -29,9 +29,9 @@ function display_solutions(m, ss)
end
cs = { mk_eq_cnstr(m1, f(m2, f(m3, m4))),
mk_choice_cnstr(m2, function(e, s, ngen) return {{a, justification(), {}}, {f(a, a), justification(), {}}} end),
mk_choice_cnstr(m3, function(e, s, ngen) return {{b, justification(), {}}, {f(b, b), justification(), {}}} end),
mk_choice_cnstr(m4, function(e, s, ngen) return {a, b} end)
mk_choice_cnstr(m2, function(m, e, s, ngen) return {{mk_eq_cnstr(m, a)}, {mk_eq_cnstr(m, f(a, a))}} end),
mk_choice_cnstr(m3, function(m, e, s, ngen) return {mk_eq_cnstr(m, b), mk_eq_cnstr(m, f(b, b))} end),
mk_choice_cnstr(m4, function(m, e, s, ngen) return {a, b} end)
}
display_solutions(m1, unify(env, cs, name_generator(), o))