perf(library/unifier): improve m_mvar_occs management

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-07-12 04:23:02 +01:00
parent c3e8e83e50
commit 50f76fd138

View file

@ -242,14 +242,14 @@ struct unifier_fn {
*/
cnstr_set m_cnstrs;
/**
\brief The following two maps are indices. The map a metavariable name \c m to the se of all constraint indices that contain \c m.
We use these indices whenever a metavariable \c m is assigned. In this case, we used these indices to
remove any constraint that contains \c m from \c m_cnstrs, instantiate \c m, and reprocess them.
\brief The following map is an index. The map a metavariable name \c m to the set of constraint indices that contain \c m.
We use these indices whenever a metavariable \c m is assigned.
When the metavariable is assigned, we used this index to remove constraints that contains \c m from \c m_cnstrs,
instantiate \c m, and reprocess them.
\remark \c m_mvar_occs is for regular metavariables, and \c m_mlvl_occs is for universe metavariables.
\remark \c m_mvar_occs is for regular metavariables.
*/
name_to_cnstrs m_mvar_occs;
name_to_cnstrs m_mlvl_occs;
/**
\brief Base class for the case-splits created by the unifier.
@ -267,12 +267,11 @@ struct unifier_fn {
substitution m_subst;
cnstr_set m_cnstrs;
name_to_cnstrs m_mvar_occs;
name_to_cnstrs m_mlvl_occs;
/** \brief Save unifier's state */
case_split(unifier_fn & u, justification const & j):
m_assumption_idx(u.m_next_assumption_idx), m_jst(j), m_subst(u.m_subst), m_cnstrs(u.m_cnstrs),
m_mvar_occs(u.m_mvar_occs), m_mlvl_occs(u.m_mlvl_occs) {
m_mvar_occs(u.m_mvar_occs) {
u.m_next_assumption_idx++;
u.m_tc->push();
}
@ -285,7 +284,6 @@ struct unifier_fn {
u.m_subst = m_subst;
u.m_cnstrs = m_cnstrs;
u.m_mvar_occs = m_mvar_occs;
u.m_mlvl_occs = m_mlvl_occs;
m_assumption_idx = u.m_next_assumption_idx;
m_failed_justifications = mk_composite1(m_failed_justifications, *u.m_conflict);
u.m_next_assumption_idx++;
@ -347,52 +345,43 @@ struct unifier_fn {
/**
\brief Update occurrence index with entry <tt>m -> cidx</tt>, where \c m is the name of a metavariable,
and \c cidx is the index of a constraint that contains \c m.
\remark \c MVar is true if \c m is a regular metavariable, and false if it is a universe metavariable.
*/
template<bool MVar>
void add_occ(name const & m, unsigned cidx) {
void add_mvar_occ(name const & m, unsigned cidx) {
cnstr_idx_set s;
name_to_cnstrs & map = MVar ? m_mvar_occs : m_mlvl_occs;
auto it = map.find(m);
auto it = m_mvar_occs.find(m);
if (it)
s = *it;
if (!s.contains(cidx)) {
s.insert(cidx);
map.insert(m, s);
m_mvar_occs.insert(m, s);
}
}
/** \see add_occ */
void add_mvar_occ(name const & m, unsigned cidx) { add_occ<true>(m, cidx); }
/** \see add_occ */
void add_mlvl_occ(name const & m, unsigned cidx) { add_occ<false>(m, cidx); }
void add_meta_occ(expr const & m, unsigned cidx) {
lean_assert(is_meta(m));
add_mvar_occ(mlocal_name(get_app_fn(m)), cidx);
}
/**
\brief Update the indices \c m_mvar_occs and \c m_mlvl_occs.
For every metavariable name \c m in \c mlvl_occs and \c mvar_occs, add an entry to \c cidx.
\remark \c cidx is the index of some constraint in \c m_cnstrs.
*/
void add_occs(unsigned cidx, name_set const * mlvl_occs, name_set const * mvar_occs) {
if (mlvl_occs) {
mlvl_occs->for_each([=](name const & m) {
add_mlvl_occ(m, cidx);
});
}
if (mvar_occs) {
mvar_occs->for_each([=](name const & m) {
add_mvar_occ(m, cidx);
void add_meta_occs(expr const & e, unsigned cidx) {
if (has_expr_metavar(e)) {
for_each(e, [&](expr const & e, unsigned) {
if (is_meta(e)) {
add_meta_occ(e, cidx);
return false;
}
if (is_local(e))
return false;
return has_expr_metavar(e);
});
}
}
/** \brief Add constraint to the constraint queue */
void add_cnstr(constraint const & c, name_set const * mlvl_occs, name_set const * mvar_occs, cnstr_group g) {
unsigned add_cnstr(constraint const & c, cnstr_group g) {
unsigned cidx = m_next_cidx + get_group_first_index(g);
m_cnstrs.insert(cnstr(c, cidx));
add_occs(cidx, mlvl_occs, mvar_occs);
m_next_cidx++;
return cidx;
}
bool is_def_eq(expr const & t1, expr const & t2, justification const & j) {
@ -438,17 +427,7 @@ struct unifier_fn {
bool assign(level const & m, level const & v, justification const & j) {
lean_assert(is_meta(m));
m_subst.d_assign(m, v, j);
auto it = m_mlvl_occs.find(meta_id(m));
if (it) {
cnstr_idx_set s = *it;
m_mlvl_occs.erase(meta_id(m));
s.for_each([&](unsigned cidx) {
process_constraint_cidx(cidx);
});
return !in_conflict();
} else {
return true;
}
return true;
}
enum status { Solved, Failed, Continue };
@ -496,51 +475,95 @@ struct unifier_fn {
return is_eq_cnstr(c) && is_eq_deltas(cnstr_lhs_expr(c), cnstr_rhs_expr(c));
}
std::pair<constraint, bool> instantiate_metavars(constraint const & c) {
if (is_eq_cnstr(c)) {
auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_expr(c));
auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_expr(c));
expr lhs = lhs_jst.first;
expr rhs = rhs_jst.first;
if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) {
return mk_pair(mk_eq_cnstr(lhs, rhs, mk_composite1(mk_composite1(c.get_justification(), lhs_jst.second), rhs_jst.second)),
true);
}
} else if (is_level_eq_cnstr(c)) {
auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_level(c));
auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_level(c));
level lhs = lhs_jst.first;
level rhs = rhs_jst.first;
if (lhs != cnstr_lhs_level(c) || rhs != cnstr_rhs_level(c)) {
return mk_pair(mk_level_eq_cnstr(lhs, rhs,
mk_composite1(mk_composite1(c.get_justification(), lhs_jst.second), rhs_jst.second)),
true);
}
}
return mk_pair(c, false);
}
status process_eq_constraint_core(constraint const & c) {
expr const & lhs = cnstr_lhs_expr(c);
expr const & rhs = cnstr_rhs_expr(c);
justification const & jst = c.get_justification();
if (lhs == rhs)
return Solved; // trivial constraint
// Update justification using the justification of the instantiated metavariables
if (!has_metavar(lhs) && !has_metavar(rhs)) {
return is_def_eq(lhs, rhs, jst) ? Solved : Failed;
}
// Handle higher-order pattern matching.
status st = process_metavar_eq(lhs, rhs, jst);
if (st != Continue) return st;
st = process_metavar_eq(rhs, lhs, jst);
if (st != Continue) return st;
return Continue;
}
/** \brief Process an equality constraints. */
bool process_eq_constraint(constraint const & c) {
lean_assert(is_eq_cnstr(c));
// instantiate assigned metavariables
name_set unassigned_lvls, unassigned_exprs;
auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_expr(c), &unassigned_lvls, &unassigned_exprs);
auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_expr(c), &unassigned_lvls, &unassigned_exprs);
expr lhs = lhs_jst.first;
expr rhs = rhs_jst.first;
if (lhs == rhs)
return true; // trivial constraint
// Update justification using the justification of the instantiated metavariables
justification new_jst = mk_composite1(mk_composite1(c.get_justification(), lhs_jst.second), rhs_jst.second);
if (!has_metavar(lhs) && !has_metavar(rhs)) {
return is_def_eq(lhs, rhs, new_jst);
}
// Handle higher-order pattern matching.
status st = process_metavar_eq(lhs, rhs, new_jst);
if (st != Continue) return st == Solved;
st = process_metavar_eq(rhs, lhs, new_jst);
auto r = instantiate_metavars(c);
constraint const & new_c = r.first;
bool modified = r.second;
status st = process_eq_constraint_core(new_c);
if (st != Continue) return st == Solved;
expr const & lhs = cnstr_lhs_expr(new_c);
expr const & rhs = cnstr_rhs_expr(new_c);
justification const & jst = new_c.get_justification();
// If lhs or rhs were updated, then invoke is_def_eq again.
if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) {
if (modified) {
// some metavariables were instantiated, try is_def_eq again
return is_def_eq(lhs, rhs, new_jst);
return is_def_eq(lhs, rhs, jst);
}
if (is_eq_deltas(lhs, rhs)) {
// we need to create a backtracking point for this one
add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::Basic);
} else if (m_plugin->delay_constraint(*m_tc, c)) {
add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::PluginDelayed);
add_cnstr(new_c, cnstr_group::Basic);
} else if (m_plugin->delay_constraint(*m_tc, new_c)) {
unsigned cidx = add_cnstr(new_c, cnstr_group::PluginDelayed);
add_meta_occs(lhs, cidx);
add_meta_occs(rhs, cidx);
} else if (is_meta(lhs) && is_meta(rhs)) {
// flex-flex constraints are delayed the most.
add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::FlexFlex);
} else if (is_meta(lhs) || is_meta(rhs)) {
unsigned cidx = add_cnstr(new_c, cnstr_group::FlexFlex);
add_meta_occ(lhs, cidx);
add_meta_occ(rhs, cidx);
} else if (is_meta(lhs)) {
// flex-rigid constraints are delayed.
add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::FlexRigid);
unsigned cidx = add_cnstr(new_c, cnstr_group::FlexRigid);
add_meta_occ(lhs, cidx);
} else if (is_meta(rhs)) {
// flex-rigid constraints are delayed.
unsigned cidx = add_cnstr(new_c, cnstr_group::FlexRigid);
add_meta_occ(rhs, cidx);
} else {
// this constraints require the unifier plugin to be solved
add_cnstr(c, &unassigned_lvls, &unassigned_exprs, cnstr_group::Basic);
add_cnstr(new_c, cnstr_group::Basic);
}
return true;
}
@ -573,11 +596,10 @@ struct unifier_fn {
bool process_level_eq_constraint(constraint const & c) {
lean_assert(is_level_eq_cnstr(c));
// instantiate assigned metavariables
name_set unassigned_lvls;
auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_level(c), &unassigned_lvls);
auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_level(c), &unassigned_lvls);
level lhs = lhs_jst.first;
level rhs = rhs_jst.first;
constraint new_c = instantiate_metavars(c).first;
level lhs = cnstr_lhs_level(new_c);
level rhs = cnstr_rhs_level(new_c);
justification jst = new_c.get_justification();
// normalize lhs and rhs
lhs = normalize(lhs);
@ -591,24 +613,17 @@ struct unifier_fn {
if (lhs == rhs)
return true; // trivial constraint
justification new_jst = mk_composite1(mk_composite1(c.get_justification(), lhs_jst.second), rhs_jst.second);
if (!has_meta(lhs) && !has_meta(rhs)) {
set_conflict(new_jst);
set_conflict(jst);
return false; // trivial failure
}
status st = process_metavar_eq(lhs, rhs, new_jst);
status st = process_metavar_eq(lhs, rhs, jst);
if (st != Continue) return st == Solved;
st = process_metavar_eq(rhs, lhs, new_jst);
st = process_metavar_eq(rhs, lhs, jst);
if (st != Continue) return st == Solved;
if (lhs != cnstr_lhs_level(c) || rhs != cnstr_rhs_level(c)) {
constraint new_c = mk_level_eq_cnstr(lhs, rhs, new_jst);
add_cnstr(new_c, &unassigned_lvls, nullptr, cnstr_group::FlexRigid);
} else {
add_cnstr(c, &unassigned_lvls, nullptr, cnstr_group::FlexRigid);
}
add_cnstr(new_c, cnstr_group::FlexRigid);
return true;
}
@ -624,7 +639,7 @@ struct unifier_fn {
switch (c.kind()) {
case constraint_kind::Choice:
// Choice constraints are never considered easy.
add_cnstr(c, nullptr, nullptr, get_choice_cnstr_group(c));
add_cnstr(c, get_choice_cnstr_group(c));
return true;
case constraint_kind::Eq:
return process_eq_constraint(c);
@ -1163,16 +1178,35 @@ struct unifier_fn {
lean_assert(!m_cnstrs.empty());
constraint c = m_cnstrs.min()->first;
m_cnstrs.erase_min();
if (is_choice_cnstr(c))
if (is_choice_cnstr(c)) {
return process_choice_constraint(c);
else if (is_delta_cnstr(c))
return process_delta(c);
else if (is_flex_rigid(c))
return process_flex_rigid(c);
else if (is_flex_flex(c))
return process_flex_flex(c);
else
return process_plugin_constraint(c);
} else {
auto r = instantiate_metavars(c);
c = r.first;
bool modified = r.second;
if (is_level_eq_cnstr(c)) {
if (modified)
return process_constraint(c);
else
return process_plugin_constraint(c);
} else {
lean_assert(is_eq_cnstr(c));
if (modified) {
status st = process_eq_constraint_core(c);
if (st != Continue) return st == Solved;
}
if (is_delta_cnstr(c))
return process_delta(c);
else if (is_flex_rigid(c))
return process_flex_rigid(c);
else if (is_flex_flex(c))
return process_flex_flex(c);
else if (modified)
return is_def_eq(cnstr_lhs_expr(c), cnstr_rhs_expr(c), c.get_justification());
else
return process_plugin_constraint(c);
}
}
}
/** \brief Return true if unifier may be able to produce more solutions */