diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp
index 2b12f9393..70a34b840 100644
--- a/src/library/unifier.cpp
+++ b/src/library/unifier.cpp
@@ -242,14 +242,14 @@ struct unifier_fn {
*/
cnstr_set m_cnstrs;
/**
- \brief The following two maps are indices. The map a metavariable name \c m to the se of all constraint indices that contain \c m.
- We use these indices whenever a metavariable \c m is assigned. In this case, we used these indices to
- remove any constraint that contains \c m from \c m_cnstrs, instantiate \c m, and reprocess them.
+ \brief The following map is an index. The map a metavariable name \c m to the set of constraint indices that contain \c m.
+ We use these indices whenever a metavariable \c m is assigned.
+ When the metavariable is assigned, we used this index to remove constraints that contains \c m from \c m_cnstrs,
+ instantiate \c m, and reprocess them.
- \remark \c m_mvar_occs is for regular metavariables, and \c m_mlvl_occs is for universe metavariables.
+ \remark \c m_mvar_occs is for regular metavariables.
*/
name_to_cnstrs m_mvar_occs;
- name_to_cnstrs m_mlvl_occs;
/**
\brief Base class for the case-splits created by the unifier.
@@ -267,12 +267,11 @@ struct unifier_fn {
substitution m_subst;
cnstr_set m_cnstrs;
name_to_cnstrs m_mvar_occs;
- name_to_cnstrs m_mlvl_occs;
/** \brief Save unifier's state */
case_split(unifier_fn & u, justification const & j):
m_assumption_idx(u.m_next_assumption_idx), m_jst(j), m_subst(u.m_subst), m_cnstrs(u.m_cnstrs),
- m_mvar_occs(u.m_mvar_occs), m_mlvl_occs(u.m_mlvl_occs) {
+ m_mvar_occs(u.m_mvar_occs) {
u.m_next_assumption_idx++;
u.m_tc->push();
}
@@ -285,7 +284,6 @@ struct unifier_fn {
u.m_subst = m_subst;
u.m_cnstrs = m_cnstrs;
u.m_mvar_occs = m_mvar_occs;
- u.m_mlvl_occs = m_mlvl_occs;
m_assumption_idx = u.m_next_assumption_idx;
m_failed_justifications = mk_composite1(m_failed_justifications, *u.m_conflict);
u.m_next_assumption_idx++;
@@ -347,52 +345,43 @@ struct unifier_fn {
/**
\brief Update occurrence index with entry m -> cidx, where \c m is the name of a metavariable,
and \c cidx is the index of a constraint that contains \c m.
-
- \remark \c MVar is true if \c m is a regular metavariable, and false if it is a universe metavariable.
*/
- template
- void add_occ(name const & m, unsigned cidx) {
+ void add_mvar_occ(name const & m, unsigned cidx) {
cnstr_idx_set s;
- name_to_cnstrs & map = MVar ? m_mvar_occs : m_mlvl_occs;
- auto it = map.find(m);
+ auto it = m_mvar_occs.find(m);
if (it)
s = *it;
if (!s.contains(cidx)) {
s.insert(cidx);
- map.insert(m, s);
+ m_mvar_occs.insert(m, s);
}
}
- /** \see add_occ */
- void add_mvar_occ(name const & m, unsigned cidx) { add_occ(m, cidx); }
- /** \see add_occ */
- void add_mlvl_occ(name const & m, unsigned cidx) { add_occ(m, cidx); }
+ void add_meta_occ(expr const & m, unsigned cidx) {
+ lean_assert(is_meta(m));
+ add_mvar_occ(mlocal_name(get_app_fn(m)), cidx);
+ }
- /**
- \brief Update the indices \c m_mvar_occs and \c m_mlvl_occs.
- For every metavariable name \c m in \c mlvl_occs and \c mvar_occs, add an entry to \c cidx.
-
- \remark \c cidx is the index of some constraint in \c m_cnstrs.
- */
- void add_occs(unsigned cidx, name_set const * mlvl_occs, name_set const * mvar_occs) {
- if (mlvl_occs) {
- mlvl_occs->for_each([=](name const & m) {
- add_mlvl_occ(m, cidx);
- });
- }
- if (mvar_occs) {
- mvar_occs->for_each([=](name const & m) {
- add_mvar_occ(m, cidx);
+ void add_meta_occs(expr const & e, unsigned cidx) {
+ if (has_expr_metavar(e)) {
+ for_each(e, [&](expr const & e, unsigned) {
+ if (is_meta(e)) {
+ add_meta_occ(e, cidx);
+ return false;
+ }
+ if (is_local(e))
+ return false;
+ return has_expr_metavar(e);
});
}
}
/** \brief Add constraint to the constraint queue */
- void add_cnstr(constraint const & c, name_set const * mlvl_occs, name_set const * mvar_occs, cnstr_group g) {
+ unsigned add_cnstr(constraint const & c, cnstr_group g) {
unsigned cidx = m_next_cidx + get_group_first_index(g);
m_cnstrs.insert(cnstr(c, cidx));
- add_occs(cidx, mlvl_occs, mvar_occs);
m_next_cidx++;
+ return cidx;
}
bool is_def_eq(expr const & t1, expr const & t2, justification const & j) {
@@ -438,17 +427,7 @@ struct unifier_fn {
bool assign(level const & m, level const & v, justification const & j) {
lean_assert(is_meta(m));
m_subst.d_assign(m, v, j);
- auto it = m_mlvl_occs.find(meta_id(m));
- if (it) {
- cnstr_idx_set s = *it;
- m_mlvl_occs.erase(meta_id(m));
- s.for_each([&](unsigned cidx) {
- process_constraint_cidx(cidx);
- });
- return !in_conflict();
- } else {
- return true;
- }
+ return true;
}
enum status { Solved, Failed, Continue };
@@ -496,51 +475,95 @@ struct unifier_fn {
return is_eq_cnstr(c) && is_eq_deltas(cnstr_lhs_expr(c), cnstr_rhs_expr(c));
}
+ std::pair instantiate_metavars(constraint const & c) {
+ if (is_eq_cnstr(c)) {
+ auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_expr(c));
+ auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_expr(c));
+ expr lhs = lhs_jst.first;
+ expr rhs = rhs_jst.first;
+ if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) {
+ return mk_pair(mk_eq_cnstr(lhs, rhs, mk_composite1(mk_composite1(c.get_justification(), lhs_jst.second), rhs_jst.second)),
+ true);
+ }
+ } else if (is_level_eq_cnstr(c)) {
+ auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_level(c));
+ auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_level(c));
+ level lhs = lhs_jst.first;
+ level rhs = rhs_jst.first;
+ if (lhs != cnstr_lhs_level(c) || rhs != cnstr_rhs_level(c)) {
+ return mk_pair(mk_level_eq_cnstr(lhs, rhs,
+ mk_composite1(mk_composite1(c.get_justification(), lhs_jst.second), rhs_jst.second)),
+ true);
+ }
+ }
+ return mk_pair(c, false);
+ }
+
+ status process_eq_constraint_core(constraint const & c) {
+ expr const & lhs = cnstr_lhs_expr(c);
+ expr const & rhs = cnstr_rhs_expr(c);
+ justification const & jst = c.get_justification();
+
+ if (lhs == rhs)
+ return Solved; // trivial constraint
+
+ // Update justification using the justification of the instantiated metavariables
+ if (!has_metavar(lhs) && !has_metavar(rhs)) {
+ return is_def_eq(lhs, rhs, jst) ? Solved : Failed;
+ }
+
+ // Handle higher-order pattern matching.
+ status st = process_metavar_eq(lhs, rhs, jst);
+ if (st != Continue) return st;
+ st = process_metavar_eq(rhs, lhs, jst);
+ if (st != Continue) return st;
+
+ return Continue;
+ }
+
/** \brief Process an equality constraints. */
bool process_eq_constraint(constraint const & c) {
lean_assert(is_eq_cnstr(c));
// instantiate assigned metavariables
- name_set unassigned_lvls, unassigned_exprs;
- auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_expr(c), &unassigned_lvls, &unassigned_exprs);
- auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_expr(c), &unassigned_lvls, &unassigned_exprs);
- expr lhs = lhs_jst.first;
- expr rhs = rhs_jst.first;
-
- if (lhs == rhs)
- return true; // trivial constraint
-
- // Update justification using the justification of the instantiated metavariables
- justification new_jst = mk_composite1(mk_composite1(c.get_justification(), lhs_jst.second), rhs_jst.second);
- if (!has_metavar(lhs) && !has_metavar(rhs)) {
- return is_def_eq(lhs, rhs, new_jst);
- }
-
- // Handle higher-order pattern matching.
- status st = process_metavar_eq(lhs, rhs, new_jst);
- if (st != Continue) return st == Solved;
- st = process_metavar_eq(rhs, lhs, new_jst);
+ auto r = instantiate_metavars(c);
+ constraint const & new_c = r.first;
+ bool modified = r.second;
+ status st = process_eq_constraint_core(new_c);
if (st != Continue) return st == Solved;
+ expr const & lhs = cnstr_lhs_expr(new_c);
+ expr const & rhs = cnstr_rhs_expr(new_c);
+ justification const & jst = new_c.get_justification();
+
// If lhs or rhs were updated, then invoke is_def_eq again.
- if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) {
+ if (modified) {
// some metavariables were instantiated, try is_def_eq again
- return is_def_eq(lhs, rhs, new_jst);
+ return is_def_eq(lhs, rhs, jst);
}
if (is_eq_deltas(lhs, rhs)) {
// we need to create a backtracking point for this one
- add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::Basic);
- } else if (m_plugin->delay_constraint(*m_tc, c)) {
- add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::PluginDelayed);
+ add_cnstr(new_c, cnstr_group::Basic);
+ } else if (m_plugin->delay_constraint(*m_tc, new_c)) {
+ unsigned cidx = add_cnstr(new_c, cnstr_group::PluginDelayed);
+ add_meta_occs(lhs, cidx);
+ add_meta_occs(rhs, cidx);
} else if (is_meta(lhs) && is_meta(rhs)) {
// flex-flex constraints are delayed the most.
- add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::FlexFlex);
- } else if (is_meta(lhs) || is_meta(rhs)) {
+ unsigned cidx = add_cnstr(new_c, cnstr_group::FlexFlex);
+ add_meta_occ(lhs, cidx);
+ add_meta_occ(rhs, cidx);
+ } else if (is_meta(lhs)) {
// flex-rigid constraints are delayed.
- add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::FlexRigid);
+ unsigned cidx = add_cnstr(new_c, cnstr_group::FlexRigid);
+ add_meta_occ(lhs, cidx);
+ } else if (is_meta(rhs)) {
+ // flex-rigid constraints are delayed.
+ unsigned cidx = add_cnstr(new_c, cnstr_group::FlexRigid);
+ add_meta_occ(rhs, cidx);
} else {
// this constraints require the unifier plugin to be solved
- add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::Basic);
+ add_cnstr(new_c, cnstr_group::Basic);
}
return true;
}
@@ -573,11 +596,10 @@ struct unifier_fn {
bool process_level_eq_constraint(constraint const & c) {
lean_assert(is_level_eq_cnstr(c));
// instantiate assigned metavariables
- name_set unassigned_lvls;
- auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_level(c), &unassigned_lvls);
- auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_level(c), &unassigned_lvls);
- level lhs = lhs_jst.first;
- level rhs = rhs_jst.first;
+ constraint new_c = instantiate_metavars(c).first;
+ level lhs = cnstr_lhs_level(new_c);
+ level rhs = cnstr_rhs_level(new_c);
+ justification jst = new_c.get_justification();
// normalize lhs and rhs
lhs = normalize(lhs);
@@ -591,24 +613,17 @@ struct unifier_fn {
if (lhs == rhs)
return true; // trivial constraint
- justification new_jst = mk_composite1(mk_composite1(c.get_justification(), lhs_jst.second), rhs_jst.second);
if (!has_meta(lhs) && !has_meta(rhs)) {
- set_conflict(new_jst);
+ set_conflict(jst);
return false; // trivial failure
}
- status st = process_metavar_eq(lhs, rhs, new_jst);
+ status st = process_metavar_eq(lhs, rhs, jst);
if (st != Continue) return st == Solved;
- st = process_metavar_eq(rhs, lhs, new_jst);
+ st = process_metavar_eq(rhs, lhs, jst);
if (st != Continue) return st == Solved;
- if (lhs != cnstr_lhs_level(c) || rhs != cnstr_rhs_level(c)) {
- constraint new_c = mk_level_eq_cnstr(lhs, rhs, new_jst);
- add_cnstr(new_c, &unassigned_lvls, nullptr, cnstr_group::FlexRigid);
- } else {
- add_cnstr(c, &unassigned_lvls, nullptr, cnstr_group::FlexRigid);
- }
-
+ add_cnstr(new_c, cnstr_group::FlexRigid);
return true;
}
@@ -624,7 +639,7 @@ struct unifier_fn {
switch (c.kind()) {
case constraint_kind::Choice:
// Choice constraints are never considered easy.
- add_cnstr(c, nullptr, nullptr, get_choice_cnstr_group(c));
+ add_cnstr(c, get_choice_cnstr_group(c));
return true;
case constraint_kind::Eq:
return process_eq_constraint(c);
@@ -1163,16 +1178,35 @@ struct unifier_fn {
lean_assert(!m_cnstrs.empty());
constraint c = m_cnstrs.min()->first;
m_cnstrs.erase_min();
- if (is_choice_cnstr(c))
+ if (is_choice_cnstr(c)) {
return process_choice_constraint(c);
- else if (is_delta_cnstr(c))
- return process_delta(c);
- else if (is_flex_rigid(c))
- return process_flex_rigid(c);
- else if (is_flex_flex(c))
- return process_flex_flex(c);
- else
- return process_plugin_constraint(c);
+ } else {
+ auto r = instantiate_metavars(c);
+ c = r.first;
+ bool modified = r.second;
+ if (is_level_eq_cnstr(c)) {
+ if (modified)
+ return process_constraint(c);
+ else
+ return process_plugin_constraint(c);
+ } else {
+ lean_assert(is_eq_cnstr(c));
+ if (modified) {
+ status st = process_eq_constraint_core(c);
+ if (st != Continue) return st == Solved;
+ }
+ if (is_delta_cnstr(c))
+ return process_delta(c);
+ else if (is_flex_rigid(c))
+ return process_flex_rigid(c);
+ else if (is_flex_flex(c))
+ return process_flex_flex(c);
+ else if (modified)
+ return is_def_eq(cnstr_lhs_expr(c), cnstr_rhs_expr(c), c.get_justification());
+ else
+ return process_plugin_constraint(c);
+ }
+ }
}
/** \brief Return true if unifier may be able to produce more solutions */