refactor(library/unifier): add option m_discard too unifier, if m_discard == false, then unsolved flex-flex constraints are returned, the unifier also does not apply "last resource" techniques that may miss many solutions.

This commit is contained in:
Leonardo de Moura 2014-09-11 14:02:17 -07:00
parent 935ba35292
commit 03902d4b45
4 changed files with 101 additions and 63 deletions

View file

@ -91,6 +91,7 @@ class elaborator : public coercion_info_manager {
// we set is to true whenever we find no_info annotation. // we set is to true whenever we find no_info annotation.
bool m_no_info; bool m_no_info;
info_manager m_pre_info_data; info_manager m_pre_info_data;
unifier_config m_unifier_config;
// Auxiliary object to "saving" elaborator state // Auxiliary object to "saving" elaborator state
struct saved_state { struct saved_state {
@ -174,7 +175,8 @@ public:
m_env(env), m_env(env),
m_ngen(ngen), m_ngen(ngen),
m_context(m_ngen.next(), ctx), m_context(m_ngen.next(), ctx),
m_full_context(m_ngen.next(), ctx) { m_full_context(m_ngen.next(), ctx),
m_unifier_config(env.m_ios.get_options(), true /* use exceptions */, true /* discard */) {
m_relax_main_opaque = false; m_relax_main_opaque = false;
m_no_info = false; m_no_info = false;
m_tc[0] = mk_type_checker_with_hints(env.m_env, m_ngen.mk_child(), false); m_tc[0] = mk_type_checker_with_hints(env.m_env, m_ngen.mk_child(), false);
@ -755,10 +757,10 @@ public:
return r.first; return r.first;
} }
lazy_list<substitution> solve(constraint_seq const & cs) { unify_result_seq solve(constraint_seq const & cs) {
buffer<constraint> tmp; buffer<constraint> tmp;
cs.linearize(tmp); cs.linearize(tmp);
return unify(env(), tmp.size(), tmp.data(), m_ngen.mk_child(), unifier_config(ios().get_options(), true)); return unify(env(), tmp.size(), tmp.data(), m_ngen.mk_child(), m_unifier_config);
} }
void display_unsolved_proof_state(expr const & mvar, proof_state const & ps, char const * msg) { void display_unsolved_proof_state(expr const & mvar, proof_state const & ps, char const * msg) {
@ -927,7 +929,7 @@ public:
r = ensure_type(r, cs); r = ensure_type(r, cs);
auto p = solve(cs).pull(); auto p = solve(cs).pull();
lean_assert(p); lean_assert(p);
substitution s = p->first; substitution s = p->first.first;
auto result = apply(s, r); auto result = apply(s, r);
copy_info_to_manager(s); copy_info_to_manager(s);
return result; return result;
@ -951,7 +953,7 @@ public:
constraint_seq cs = t_cs + r_v_cs.second + v_cs; constraint_seq cs = t_cs + r_v_cs.second + v_cs;
auto p = solve(cs).pull(); auto p = solve(cs).pull();
lean_assert(p); lean_assert(p);
substitution s = p->first; substitution s = p->first.first;
name_set univ_params = collect_univ_params(r_v, collect_univ_params(r_t)); name_set univ_params = collect_univ_params(r_v, collect_univ_params(r_t));
buffer<name> new_params; buffer<name> new_params;
expr new_r_t = apply(s, r_t, univ_params, new_params); expr new_r_t = apply(s, r_t, univ_params, new_params);

View file

@ -110,9 +110,11 @@ proof_state_seq apply_tactic_core(environment const & env, io_state const & ios,
} }
} }
list<expr> meta_lst = to_list(metas.begin(), metas.end()); list<expr> meta_lst = to_list(metas.begin(), metas.end());
lazy_list<substitution> substs = unify(env, t, e_t, ngen.mk_child(), relax_main_opaque, s.get_subst(), unify_result_seq rseq = unify(env, t, e_t, ngen.mk_child(), relax_main_opaque, s.get_subst(),
unifier_config(ios.get_options())); unifier_config(ios.get_options()));
return map2<proof_state>(substs, [=](substitution const & subst) -> proof_state { return map2<proof_state>(rseq, [=](pair<substitution, constraints> const & p) -> proof_state {
substitution const & subst = p.first;
// TODO(Leo): save postponed constraints
name_generator new_ngen(ngen); name_generator new_ngen(ngen);
type_checker tc(env, new_ngen.mk_child()); type_checker tc(env, new_ngen.mk_child());
substitution new_subst = subst; substitution new_subst = subst;

View file

@ -57,18 +57,20 @@ bool get_unifier_expensive_classes(options const & opts) {
return opts.get_bool(g_unifier_expensive_classes, LEAN_DEFAULT_UNIFIER_EXPENSIVE_CLASSES); return opts.get_bool(g_unifier_expensive_classes, LEAN_DEFAULT_UNIFIER_EXPENSIVE_CLASSES);
} }
unifier_config::unifier_config(bool use_exceptions): unifier_config::unifier_config(bool use_exceptions, bool discard):
m_use_exceptions(use_exceptions), m_use_exceptions(use_exceptions),
m_max_steps(LEAN_DEFAULT_UNIFIER_MAX_STEPS), m_max_steps(LEAN_DEFAULT_UNIFIER_MAX_STEPS),
m_computation(LEAN_DEFAULT_UNIFIER_COMPUTATION), m_computation(LEAN_DEFAULT_UNIFIER_COMPUTATION),
m_expensive_classes(LEAN_DEFAULT_UNIFIER_EXPENSIVE_CLASSES) { m_expensive_classes(LEAN_DEFAULT_UNIFIER_EXPENSIVE_CLASSES),
m_discard(discard) {
} }
unifier_config::unifier_config(options const & o, bool use_exceptions): unifier_config::unifier_config(options const & o, bool use_exceptions, bool discard):
m_use_exceptions(use_exceptions), m_use_exceptions(use_exceptions),
m_max_steps(get_unifier_max_steps(o)), m_max_steps(get_unifier_max_steps(o)),
m_computation(get_unifier_computation(o)), m_computation(get_unifier_computation(o)),
m_expensive_classes(get_unifier_expensive_classes(o)) { m_expensive_classes(get_unifier_expensive_classes(o)),
m_discard(discard) {
} }
// If \c e is a metavariable ?m or a term of the form (?m l_1 ... l_n) where // If \c e is a metavariable ?m or a term of the form (?m l_1 ... l_n) where
@ -289,6 +291,7 @@ struct unifier_fn {
environment m_env; environment m_env;
name_generator m_ngen; name_generator m_ngen;
substitution m_subst; substitution m_subst;
constraints m_postponed; // constraints that will not be solved
owned_map m_owned_map; // mapping from metavariable name m to delay factor of the choice constraint that owns m owned_map m_owned_map; // mapping from metavariable name m to delay factor of the choice constraint that owns m
unifier_plugin m_plugin; unifier_plugin m_plugin;
type_checker_ptr m_tc[2]; type_checker_ptr m_tc[2];
@ -334,6 +337,7 @@ struct unifier_fn {
justification m_failed_justifications; // justifications for failed branches justification m_failed_justifications; // justifications for failed branches
// snapshot of unifier's state // snapshot of unifier's state
substitution m_subst; substitution m_subst;
constraints m_postponed;
cnstr_set m_cnstrs; cnstr_set m_cnstrs;
name_to_cnstrs m_mvar_occs; name_to_cnstrs m_mvar_occs;
owned_map m_owned_map; owned_map m_owned_map;
@ -341,7 +345,8 @@ struct unifier_fn {
/** \brief Save unifier's state */ /** \brief Save unifier's state */
case_split(unifier_fn & u, justification const & j): case_split(unifier_fn & u, justification const & j):
m_assumption_idx(u.m_next_assumption_idx), m_jst(j), m_subst(u.m_subst), m_cnstrs(u.m_cnstrs), m_assumption_idx(u.m_next_assumption_idx), m_jst(j), m_subst(u.m_subst),
m_postponed(u.m_postponed), m_cnstrs(u.m_cnstrs),
m_mvar_occs(u.m_mvar_occs), m_owned_map(u.m_owned_map), m_pattern(u.m_pattern) { m_mvar_occs(u.m_mvar_occs), m_owned_map(u.m_owned_map), m_pattern(u.m_pattern) {
u.m_next_assumption_idx++; u.m_next_assumption_idx++;
} }
@ -350,6 +355,7 @@ struct unifier_fn {
void restore_state(unifier_fn & u) { void restore_state(unifier_fn & u) {
lean_assert(u.in_conflict()); lean_assert(u.in_conflict());
u.m_subst = m_subst; u.m_subst = m_subst;
u.m_postponed = m_postponed;
u.m_cnstrs = m_cnstrs; u.m_cnstrs = m_cnstrs;
u.m_mvar_occs = m_mvar_occs; u.m_mvar_occs = m_mvar_occs;
u.m_owned_map = m_owned_map; u.m_owned_map = m_owned_map;
@ -1030,14 +1036,6 @@ struct unifier_fn {
return false; return false;
} }
optional<substitution> failure() {
lean_assert(in_conflict());
if (m_config.m_use_exceptions)
throw unifier_exception(*m_conflict, m_subst);
else
return optional<substitution>();
}
bool next_lazy_constraints_case_split(lazy_constraints_case_split & cs) { bool next_lazy_constraints_case_split(lazy_constraints_case_split & cs) {
auto r = cs.m_tail.pull(); auto r = cs.m_tail.pull();
if (r) { if (r) {
@ -1774,28 +1772,40 @@ struct unifier_fn {
return process_flex_rigid(rhs, lhs, c.get_justification(), relax); return process_flex_rigid(rhs, lhs, c.get_justification(), relax);
} }
void discard(constraint const & c) {
if (!m_config.m_discard)
m_postponed = cons(c, m_postponed);
}
bool process_flex_flex(constraint const & c) { bool process_flex_flex(constraint const & c) {
expr const & lhs = cnstr_lhs_expr(c); expr const & lhs = cnstr_lhs_expr(c);
expr const & rhs = cnstr_rhs_expr(c); expr const & rhs = cnstr_rhs_expr(c);
// We ignore almost all flex-flex constraints. // We ignore almost all flex-flex constraints.
// We just handle flex_flex "first-order" case // We just handle flex_flex "first-order" case
// ?M_1 l_1 ... l_k =?= ?M_2 l_1 ... l_k // ?M_1 l_1 ... l_k =?= ?M_2 l_1 ... l_k
if (!is_simple_meta(lhs) || !is_simple_meta(rhs)) if (!is_simple_meta(lhs) || !is_simple_meta(rhs)) {
discard(c);
return true; return true;
}
buffer<expr> lhs_args, rhs_args; buffer<expr> lhs_args, rhs_args;
expr ml = get_app_args(lhs, lhs_args); expr ml = get_app_args(lhs, lhs_args);
expr mr = get_app_args(rhs, rhs_args); expr mr = get_app_args(rhs, rhs_args);
if (ml == mr || lhs_args.size() != rhs_args.size()) if (ml == mr || lhs_args.size() != rhs_args.size()) {
discard(c);
return true; return true;
}
lean_assert(!m_subst.is_assigned(ml)); lean_assert(!m_subst.is_assigned(ml));
lean_assert(!m_subst.is_assigned(mr)); lean_assert(!m_subst.is_assigned(mr));
unsigned i = 0; unsigned i = 0;
for (; i < lhs_args.size(); i++) for (; i < lhs_args.size(); i++)
if (mlocal_name(lhs_args[i]) != mlocal_name(rhs_args[i])) if (mlocal_name(lhs_args[i]) != mlocal_name(rhs_args[i]))
break; break;
if (i == lhs_args.size()) if (i == lhs_args.size()) {
return assign(ml, mr, c.get_justification(), relax_main_opaque(c), lhs, rhs); return assign(ml, mr, c.get_justification(), relax_main_opaque(c), lhs, rhs);
return true; } else {
discard(c);
return true;
}
} }
/** /**
@ -1893,9 +1903,16 @@ struct unifier_fn {
add_cnstr(c, cnstr_group::FlexFlex); add_cnstr(c, cnstr_group::FlexFlex);
return true; return true;
} }
st = try_merge_max(c); if (m_config.m_discard) {
if (st != Continue) return st == Solved; // try_merge_max is too imprecise, and is only used if we are discarding
return process_plugin_constraint(c); // constraints that cannot be solved.
st = try_merge_max(c);
if (st != Continue) return st == Solved;
return process_plugin_constraint(c);
} else {
discard(c);
return true;
}
} else { } else {
lean_assert(is_eq_cnstr(c)); lean_assert(is_eq_cnstr(c));
if (is_delta_cnstr(c)) { if (is_delta_cnstr(c)) {
@ -1924,8 +1941,18 @@ struct unifier_fn {
return !in_conflict() || !m_case_splits.empty(); return !in_conflict() || !m_case_splits.empty();
} }
typedef optional<pair<substitution, constraints>> next_result;
next_result failure() {
lean_assert(in_conflict());
if (m_config.m_use_exceptions)
throw unifier_exception(*m_conflict, m_subst);
else
return next_result();
}
/** \brief Produce the next solution */ /** \brief Produce the next solution */
optional<substitution> next() { next_result next() {
if (!more_solutions()) if (!more_solutions())
return failure(); return failure();
if (!m_first && !m_case_splits.empty()) { if (!m_first && !m_case_splits.empty()) {
@ -1940,8 +1967,9 @@ struct unifier_fn {
} else { } else {
// This is not the first run, and there are no case-splits. // This is not the first run, and there are no case-splits.
// We don't throw an exception since there are no more solutions. // We don't throw an exception since there are no more solutions.
return optional<substitution>(); return next_result();
} }
while (true) { while (true) {
if (!in_conflict()) { if (!in_conflict()) {
if (m_cnstrs.empty()) if (m_cnstrs.empty())
@ -1955,39 +1983,39 @@ struct unifier_fn {
lean_assert(m_cnstrs.empty()); lean_assert(m_cnstrs.empty());
substitution s = m_subst; substitution s = m_subst;
s.forget_justifications(); s.forget_justifications();
return optional<substitution>(s); return next_result(mk_pair(s, m_postponed));
} }
}; };
lazy_list<substitution> unify(std::shared_ptr<unifier_fn> u) { unify_result_seq unify(std::shared_ptr<unifier_fn> u) {
if (!u->more_solutions()) { if (!u->more_solutions()) {
u->failure(); // make sure exception is thrown if u->m_use_exception is true u->failure(); // make sure exception is thrown if u->m_use_exception is true
return lazy_list<substitution>(); return unify_result_seq();
} else { } else {
return mk_lazy_list<substitution>([=]() { return mk_lazy_list<pair<substitution, constraints>>([=]() {
auto s = u->next(); auto s = u->next();
if (s) if (s)
return some(mk_pair(*s, unify(u))); return some(mk_pair(*s, unify(u)));
else else
return lazy_list<substitution>::maybe_pair(); return unify_result_seq::maybe_pair();
}); });
} }
} }
lazy_list<substitution> unify(environment const & env, unsigned num_cs, constraint const * cs, name_generator const & ngen, unify_result_seq unify(environment const & env, unsigned num_cs, constraint const * cs, name_generator const & ngen,
unifier_config const & cfg) { unifier_config const & cfg) {
return unify(std::make_shared<unifier_fn>(env, num_cs, cs, ngen, substitution(), cfg)); return unify(std::make_shared<unifier_fn>(env, num_cs, cs, ngen, substitution(), cfg));
} }
lazy_list<substitution> unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen, unify_result_seq unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen,
bool relax, substitution const & s, unifier_config const & cfg) { bool relax, substitution const & s, unifier_config const & cfg) {
substitution new_s = s; substitution new_s = s;
expr _lhs = new_s.instantiate(lhs); expr _lhs = new_s.instantiate(lhs);
expr _rhs = new_s.instantiate(rhs); expr _rhs = new_s.instantiate(rhs);
auto u = std::make_shared<unifier_fn>(env, 0, nullptr, ngen, new_s, cfg); auto u = std::make_shared<unifier_fn>(env, 0, nullptr, ngen, new_s, cfg);
constraint_seq cs; constraint_seq cs;
if (!u->m_tc[relax]->is_def_eq(_lhs, _rhs, justification(), cs) || !u->process_constraints(cs)) { if (!u->m_tc[relax]->is_def_eq(_lhs, _rhs, justification(), cs) || !u->process_constraints(cs)) {
return lazy_list<substitution>(); return unify_result_seq();
} else { } else {
return unify(u); return unify(u);
} }
@ -2009,31 +2037,31 @@ static int unify_simple(lua_State * L) {
return push_integer(L, static_cast<unsigned>(r)); return push_integer(L, static_cast<unsigned>(r));
} }
typedef lazy_list<substitution> substitution_seq; DECL_UDATA(unify_result_seq)
DECL_UDATA(substitution_seq)
static const struct luaL_Reg substitution_seq_m[] = { static const struct luaL_Reg unify_result_seq_m[] = {
{"__gc", substitution_seq_gc}, {"__gc", unify_result_seq_gc},
{0, 0} {0, 0}
}; };
static int substitution_seq_next(lua_State * L) { static int unify_result_seq_next(lua_State * L) {
substitution_seq seq = to_substitution_seq(L, lua_upvalueindex(1)); unify_result_seq seq = to_unify_result_seq(L, lua_upvalueindex(1));
substitution_seq::maybe_pair p; unify_result_seq::maybe_pair p;
p = seq.pull(); p = seq.pull();
if (p) { if (p) {
push_substitution_seq(L, p->second); push_unify_result_seq(L, p->second);
lua_replace(L, lua_upvalueindex(1)); lua_replace(L, lua_upvalueindex(1));
push_substitution(L, p->first); push_substitution(L, p->first.first);
// TODO(Leo): return postponed constraints
} else { } else {
lua_pushnil(L); lua_pushnil(L);
} }
return 1; return 1;
} }
static int push_substitution_seq_it(lua_State * L, substitution_seq const & seq) { static int push_unify_result_seq_it(lua_State * L, unify_result_seq const & seq) {
push_substitution_seq(L, seq); push_unify_result_seq(L, seq);
lua_pushcclosure(L, &safe_function<substitution_seq_next>, 1); // create closure with 1 upvalue lua_pushcclosure(L, &safe_function<unify_result_seq_next>, 1); // create closure with 1 upvalue
return 1; return 1;
} }
@ -2109,7 +2137,7 @@ static name g_tmp_prefix = name::mk_internal_unique_name();
static int unify(lua_State * L) { static int unify(lua_State * L) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
lazy_list<substitution> r; unify_result_seq r;
environment const & env = to_environment(L, 1); environment const & env = to_environment(L, 1);
if (is_expr(L, 2)) { if (is_expr(L, 2)) {
if (nargs == 7) if (nargs == 7)
@ -2127,15 +2155,15 @@ static int unify(lua_State * L) {
else else
r = unify(env, cs.size(), cs.data(), to_name_generator(L, 3)); r = unify(env, cs.size(), cs.data(), to_name_generator(L, 3));
} }
return push_substitution_seq_it(L, r); return push_unify_result_seq_it(L, r);
} }
void open_unifier(lua_State * L) { void open_unifier(lua_State * L) {
luaL_newmetatable(L, substitution_seq_mt); luaL_newmetatable(L, unify_result_seq_mt);
lua_pushvalue(L, -1); lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index"); lua_setfield(L, -2, "__index");
setfuncs(L, substitution_seq_m, 0); setfuncs(L, unify_result_seq_m, 0);
SET_GLOBAL_FUN(substitution_seq_pred, "is_substitution_seq"); SET_GLOBAL_FUN(unify_result_seq_pred, "is_unify_result_seq");
SET_GLOBAL_FUN(unify_simple, "unify_simple"); SET_GLOBAL_FUN(unify_simple, "unify_simple");
SET_GLOBAL_FUN(unify, "unify"); SET_GLOBAL_FUN(unify, "unify");

View file

@ -37,14 +37,20 @@ struct unifier_config {
unsigned m_max_steps; unsigned m_max_steps;
bool m_computation; bool m_computation;
bool m_expensive_classes; bool m_expensive_classes;
unifier_config(bool use_exceptions = false); // If m_discard is true, then constraints that cannot be solved are discarded (or incomplete methods are used)
explicit unifier_config(options const & o, bool use_exceptions = false); // If m_discard is false, unify returns the set of constraints that could not be handled.
bool m_discard;
unifier_config(bool use_exceptions = false, bool discard = false);
explicit unifier_config(options const & o, bool use_exceptions = false, bool discard = false);
}; };
lazy_list<substitution> unify(environment const & env, unsigned num_cs, constraint const * cs, name_generator const & ngen, /** \brief The unification procedures produce a lazy list of pair substitution + constraints that could not be solved. */
unifier_config const & c = unifier_config()); typedef lazy_list<pair<substitution, constraints>> unify_result_seq;
lazy_list<substitution> unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen, bool relax_main_opaque,
substitution const & s = substitution(), unifier_config const & c = unifier_config()); unify_result_seq unify(environment const & env, unsigned num_cs, constraint const * cs, name_generator const & ngen,
unifier_config const & c = unifier_config());
unify_result_seq unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen, bool relax_main_opaque,
substitution const & s = substitution(), unifier_config const & c = unifier_config());
/** /**
The unifier divides the constraints in 8 groups: Simple, Basic, FlexRigid, PluginDelayed, DelayedChoice, ClassInstance, FlexFlex, MaxDelayed The unifier divides the constraints in 8 groups: Simple, Basic, FlexRigid, PluginDelayed, DelayedChoice, ClassInstance, FlexFlex, MaxDelayed