perf(library/unifier): minimize the use of instantiate_metavars

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-07-15 02:34:27 +01:00
parent 29c7eeaa99
commit a18cf94d09
2 changed files with 87 additions and 47 deletions

View file

@ -26,9 +26,6 @@ class substitution {
friend class instantiate_metavars_fn;
std::pair<level, justification> d_instantiate_metavars(level const & l, bool use_jst, bool updt);
justification get_expr_jst(name const & m) const { if (auto it = m_expr_jsts.find(m)) return *it; else return justification(); }
justification get_level_jst(name const & m) const { if (auto it = m_level_jsts.find(m)) return *it; else return justification(); }
public:
substitution();
typedef optional<std::pair<expr, justification>> opt_expr_jst;
@ -42,6 +39,8 @@ public:
optional<expr> get_expr(name const & m) const;
optional<level> get_level(name const & m) const;
justification get_expr_jst(name const & m) const { if (auto it = m_expr_jsts.find(m)) return *it; else return justification(); }
justification get_level_jst(name const & m) const { if (auto it = m_level_jsts.find(m)) return *it; else return justification(); }
substitution assign(name const & m, expr const & t, justification const & j) const;
substitution assign(name const & m, expr const & t) const;

View file

@ -69,8 +69,6 @@ bool context_check(expr const & e, buffer<expr> const & locals) {
// - l_false if \c e contains \c m or it contains a local constant \c l
// not in locals that is not in a metavariable application.
lbool occurs_context_check(substitution const & s, expr const & e, expr const & m, buffer<expr> const & locals) {
if (s.occurs(m, e))
return l_false;
expr root = e;
lbool r = l_true;
for_each(e, [&](expr const & e, unsigned) {
@ -84,8 +82,10 @@ lbool occurs_context_check(substitution const & s, expr const & e, expr const &
}
return false; // do not visit type
} else if (is_meta(e)) {
if (!context_check(e, locals))
if (!context_check(e, locals) || s.occurs(m, e))
r = l_undef;
if (get_app_fn(e) == m)
r = l_false;
return false; // do not visit children
} else {
// we only need to continue exploring e if it contains
@ -524,49 +524,87 @@ struct unifier_fn {
return Continue;
}
expr instantiate_meta(expr const & e, justification & j) {
expr const & f = get_app_fn(e);
if (!is_metavar(f))
return e;
auto r = m_subst.d_instantiate_metavars(f);
if (is_metavar(r.first))
return e;
buffer<expr> args;
get_app_rev_args(e, args);
j = mk_composite1(j, r.second);
return apply_beta(r.first, args.size(), args.data());
}
expr instantiate_meta_args(expr const & e, justification & j) {
if (!is_app(e))
return e;
buffer<expr> args;
bool modified = false;
expr const & f = get_app_rev_args(e, args);
unsigned i = args.size();
while (i > 0) {
--i;
expr new_arg = instantiate_meta(args[i], j);
if (new_arg != args[i]) {
modified = true;
args[i] = new_arg;
}
}
if (!modified)
return e;
return mk_rev_app(f, args.size(), args.data());
}
status instantiate_eq_cnstr(constraint const & c) {
justification j = c.get_justification();
expr lhs = instantiate_meta(cnstr_lhs_expr(c), j);
expr rhs = instantiate_meta(cnstr_rhs_expr(c), j);
if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c))
return is_def_eq(lhs, rhs, j) ? Solved : Failed;
lhs = instantiate_meta_args(lhs, j);
rhs = instantiate_meta_args(rhs, j);
if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c))
return is_def_eq(lhs, rhs, j) ? Solved : Failed;
return Continue;
}
/** \brief Process an equality constraints. */
bool process_eq_constraint(constraint const & c) {
lean_assert(is_eq_cnstr(c));
// instantiate assigned metavariables
auto r = instantiate_metavars(c);
constraint const & new_c = r.first;
bool modified = r.second;
status st = process_eq_constraint_core(new_c);
status st = instantiate_eq_cnstr(c);
if (st != Continue) return st == Solved;
st = process_eq_constraint_core(c);
if (st != Continue) return st == Solved;
expr const & lhs = cnstr_lhs_expr(new_c);
expr const & rhs = cnstr_rhs_expr(new_c);
justification const & jst = new_c.get_justification();
// If lhs or rhs were updated, then invoke is_def_eq again.
if (modified) {
// some metavariables were instantiated, try is_def_eq again
return is_def_eq(lhs, rhs, jst);
}
expr const & lhs = cnstr_lhs_expr(c);
expr const & rhs = cnstr_rhs_expr(c);
if (is_eq_deltas(lhs, rhs)) {
// we need to create a backtracking point for this one
add_cnstr(new_c, cnstr_group::Basic);
} else if (m_plugin->delay_constraint(*m_tc, new_c)) {
unsigned cidx = add_cnstr(new_c, cnstr_group::PluginDelayed);
add_cnstr(c, cnstr_group::Basic);
} else if (m_plugin->delay_constraint(*m_tc, c)) {
unsigned cidx = add_cnstr(c, cnstr_group::PluginDelayed);
add_meta_occs(lhs, cidx);
add_meta_occs(rhs, cidx);
} else if (is_meta(lhs) && is_meta(rhs)) {
// flex-flex constraints are delayed the most.
unsigned cidx = add_cnstr(new_c, cnstr_group::FlexFlex);
unsigned cidx = add_cnstr(c, cnstr_group::FlexFlex);
add_meta_occ(lhs, cidx);
add_meta_occ(rhs, cidx);
} else if (is_meta(lhs)) {
// flex-rigid constraints are delayed.
unsigned cidx = add_cnstr(new_c, cnstr_group::FlexRigid);
unsigned cidx = add_cnstr(c, cnstr_group::FlexRigid);
add_meta_occ(lhs, cidx);
} else if (is_meta(rhs)) {
// flex-rigid constraints are delayed.
unsigned cidx = add_cnstr(new_c, cnstr_group::FlexRigid);
unsigned cidx = add_cnstr(c, cnstr_group::FlexRigid);
add_meta_occ(rhs, cidx);
} else {
// this constraints require the unifier plugin to be solved
add_cnstr(new_c, cnstr_group::Basic);
add_cnstr(c, cnstr_group::Basic);
}
return true;
}
@ -806,28 +844,35 @@ struct unifier_fn {
// the constraint is not well-formed, this can happen when users are abusing the API
return false;
}
buffer<constraint> cs_buffer;
while (!is_nil(lhs_lvls)) {
cs_buffer.push_back(mk_level_eq_cnstr(head(lhs_lvls), head(rhs_lvls), j));
lhs_lvls = tail(lhs_lvls);
rhs_lvls = tail(rhs_lvls);
}
unsigned i = lhs_args.size();
while (i > 0) {
--i;
cs_buffer.push_back(mk_eq_cnstr(lhs_args[i], rhs_args[i], j));
}
constraints cs1 = to_list(cs_buffer.begin(), cs_buffer.end());
justification a = mk_assumption_justification(m_next_assumption_idx);
// add case_split for t =?= s
expr lhs_fn_val = instantiate_univ_params(d.get_value(), d.get_univ_params(), const_levels(lhs_fn));
expr rhs_fn_val = instantiate_univ_params(d.get_value(), d.get_univ_params(), const_levels(rhs_fn));
expr t = apply_beta(lhs_fn_val, lhs_args.size(), lhs_args.data());
expr s = apply_beta(rhs_fn_val, rhs_args.size(), rhs_args.data());
constraints cs2(mk_eq_cnstr(t, s, j));
justification a = mk_assumption_justification(m_next_assumption_idx);
add_case_split(std::unique_ptr<case_split>(new simple_case_split(*this, j, list<constraints>(cs2))));
return process_constraints(cs1, a);
// process first case
justification new_j = mk_composite1(j, a);
while (!is_nil(lhs_lvls)) {
level lhs = head(lhs_lvls);
level rhs = head(rhs_lvls);
if (!process_constraint(mk_level_eq_cnstr(lhs, rhs, new_j)))
return false;
lhs_lvls = tail(lhs_lvls);
rhs_lvls = tail(rhs_lvls);
}
unsigned i = lhs_args.size();
while (i > 0) {
--i;
if (!is_def_eq(lhs_args[i], rhs_args[i], new_j))
return false;
}
return true;
}
/** \brief Return true iff \c c is a flex-rigid constraint. */
@ -1197,18 +1242,14 @@ struct unifier_fn {
return process_plugin_constraint(c);
} else {
lean_assert(is_eq_cnstr(c));
if (modified) {
status st = process_eq_constraint_core(c);
if (st != Continue) return st == Solved;
}
if (is_delta_cnstr(c))
return process_delta(c);
else if (modified)
return is_def_eq(cnstr_lhs_expr(c), cnstr_rhs_expr(c), c.get_justification());
else if (is_flex_rigid(c))
return process_flex_rigid(c);
else if (is_flex_flex(c))
return process_flex_flex(c);
else if (modified)
return is_def_eq(cnstr_lhs_expr(c), cnstr_rhs_expr(c), c.get_justification());
else
return process_plugin_constraint(c);
}