feat(library/unifier): add support for choice constraint
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
611f29a954
commit
228f51dcfa
2 changed files with 52 additions and 8 deletions
|
@ -38,7 +38,7 @@ typedef std::tuple<expr, justification, constraints> choice_fn_result;
|
|||
/**
|
||||
\brief A choice_fn is used to enumerate the possible solutions for a metavariable.
|
||||
The input arguments are:
|
||||
- an inferred type (if available)
|
||||
- an inferred type
|
||||
- substitution map (metavar -> value)
|
||||
- name generator
|
||||
|
||||
|
@ -49,7 +49,7 @@ typedef std::tuple<expr, justification, constraints> choice_fn_result;
|
|||
|
||||
One application of choice constraints is overloaded notation.
|
||||
*/
|
||||
typedef std::function<lazy_list<choice_fn_result>(optional<expr> const &, substitution const &, name_generator const &)> choice_fn;
|
||||
typedef std::function<lazy_list<choice_fn_result>(expr const &, substitution const &, name_generator const &)> choice_fn;
|
||||
|
||||
struct constraint_cell;
|
||||
class constraint {
|
||||
|
|
|
@ -220,6 +220,15 @@ struct unifier_fn {
|
|||
virtual bool next(unifier_fn & u) { return u.next_plugin_case_split(*this); }
|
||||
};
|
||||
|
||||
struct choice_case_split : public case_split {
|
||||
expr m_mvar;
|
||||
justification m_jst;
|
||||
lazy_list<choice_fn_result> m_tail;
|
||||
choice_case_split(unifier_fn & u, expr const & mvar, justification const & j, lazy_list<choice_fn_result> const & tail):
|
||||
case_split(u), m_mvar(mvar), m_jst(j), m_tail(tail) {}
|
||||
virtual bool next(unifier_fn & u) { return u.next_choice_case_split(*this); }
|
||||
};
|
||||
|
||||
case_split_stack m_case_splits;
|
||||
optional<justification> m_conflict;
|
||||
|
||||
|
@ -502,12 +511,6 @@ struct unifier_fn {
|
|||
return optional<substitution>();
|
||||
}
|
||||
|
||||
bool process_choice_constraint(constraint const & c) {
|
||||
lean_assert(is_choice_cnstr(c));
|
||||
// TODO(Leo)
|
||||
return true;
|
||||
}
|
||||
|
||||
// Process constraints in \c cs, and append justification \c j to them.
|
||||
bool process_constraints(constraints const & cs, justification const & j) {
|
||||
for (constraint const & c : cs)
|
||||
|
@ -515,6 +518,47 @@ struct unifier_fn {
|
|||
return !in_conflict();
|
||||
}
|
||||
|
||||
bool process_choice_result(expr const & m, choice_fn_result const & r, justification j) {
|
||||
j = mk_composite1(j, std::get<1>(r));
|
||||
return
|
||||
process_constraint(mk_eq_cnstr(m, std::get<0>(r), j)) &&
|
||||
process_constraints(std::get<2>(r), j);
|
||||
}
|
||||
|
||||
bool next_choice_case_split(choice_case_split & cs) {
|
||||
auto r = cs.m_tail.pull();
|
||||
if (r) {
|
||||
cs.restore_state(*this);
|
||||
lean_assert(!in_conflict());
|
||||
cs.m_tail = r->second;
|
||||
justification a = mk_assumption_justification(cs.m_assumption_idx);
|
||||
return process_choice_result(cs.m_mvar, r->first, mk_composite1(cs.m_jst, a));
|
||||
} else {
|
||||
// update conflict
|
||||
update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool process_choice_constraint(constraint const & c) {
|
||||
lean_assert(is_choice_cnstr(c));
|
||||
expr const & m = cnstr_mvar(c);
|
||||
choice_fn const & fn = cnstr_choice_fn(c);
|
||||
lean_assert(is_meta(m));
|
||||
auto m_type_jst = m_subst.instantiate_metavars(m_tc.infer(m), nullptr, nullptr);
|
||||
auto rlist = fn(m_type_jst.first, m_subst, m_ngen.mk_child());
|
||||
auto r = rlist.pull();
|
||||
justification j = mk_composite1(c.get_justification(), m_type_jst.second);
|
||||
if (r) {
|
||||
justification a = mk_assumption_justification(m_next_assumption_idx);
|
||||
add_case_split(std::unique_ptr<case_split>(new choice_case_split(*this, m, m_type_jst.second, r->second)));
|
||||
return process_choice_result(m, r->first, mk_composite1(j, a));
|
||||
} else {
|
||||
set_conflict(j);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool next_plugin_case_split(plugin_case_split & cs) {
|
||||
auto r = cs.m_tail.pull();
|
||||
if (r) {
|
||||
|
|
Loading…
Reference in a new issue