perf(library/unifier): improve flex_rigid performance

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-08-08 19:18:45 -07:00
parent 9d13f634f3
commit 0af55beb56

View file

@ -12,6 +12,7 @@ Author: Leonardo de Moura
#include "util/lazy_list_fn.h"
#include "util/sstream.h"
#include "util/lbool.h"
#include "util/flet.h"
#include "kernel/for_each_fn.h"
#include "kernel/abstract.h"
#include "kernel/instantiate.h"
@ -437,6 +438,7 @@ struct unifier_fn {
if (m_tc[relax]->is_def_eq(t1, t2, j)) {
return true;
} else {
// std::cout << "conflict: " << t1 << " =?= " << t2 << "\n";
set_conflict(j);
return false;
}
@ -894,6 +896,7 @@ struct unifier_fn {
if (in_conflict())
return false;
check_system();
// std::cout << "process: " << c << "\n";
switch (c.kind()) {
case constraint_kind::Choice:
return preprocess_choice_constraint(c);
@ -1178,11 +1181,20 @@ struct unifier_fn {
buffer<expr> margs;
expr const & m;
expr const & rhs;
justification const & j;
justification j;
bool relax;
buffer<constraints> & alts; // result: alternatives
optional<bool> _has_meta_args;
/** \brief Return true if margs contains an expression \c e s.t. is_meta(e) */
bool has_meta_args() {
if (!_has_meta_args) {
_has_meta_args = std::any_of(margs.begin(), margs.end(),
[](expr const & e) { return is_meta(e); });
}
return *_has_meta_args;
}
/**
\brief Given t
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n]</tt>
@ -1191,14 +1203,31 @@ struct unifier_fn {
\remark v has free variables.
*/
static expr mk_lambda_for(expr const & t, expr const & v) {
if (is_pi(t)) {
return mk_lambda(binding_name(t), binding_domain(t), mk_lambda_for(binding_body(t), v), binding_info(t));
expr mk_lambda_for(unsigned i, expr const & t, expr const & v) {
if (i < margs.size()) {
return mk_lambda(binding_name(t), binding_domain(t), mk_lambda_for(i+1, binding_body(t), v), binding_info(t));
} else {
return v;
}
}
expr mk_lambda_for(expr const & t, expr const & v) {
return mk_lambda_for(0, t, v);
}
/** \brief Return true if \c local occurs once in the buffer \c es. */
bool local_occurs_once(expr const & local, buffer<expr> const & es) {
bool found = false;
for (expr const & e : es) {
if (is_local(e) && mlocal_name(e) == mlocal_name(local)) {
if (found)
return false;
found = true;
}
}
return true;
}
/** \brief Copy pending constraints in u.m_tc[relax] to cs and append justification j to them */
void copy_pending_constraints(buffer<constraint> & cs) {
while (auto c = u.m_tc[relax]->next_cnstr())
@ -1247,7 +1276,8 @@ struct unifier_fn {
lean_assert(is_sort(rhs) || is_constant(rhs));
expr const & mtype = mlocal_type(m);
buffer<constraint> cs;
cs.push_back(mk_eq_cnstr(m, mk_lambda_for(mtype, rhs), j, relax));
auto new_mtype = ensure_sufficient_args(mtype, cs);
cs.push_back(mk_eq_cnstr(m, mk_lambda_for(new_mtype, rhs), j, relax));
alts.push_back(to_list(cs.begin(), cs.end()));
}
@ -1377,48 +1407,136 @@ struct unifier_fn {
}
}
void mk_app_projections() {
lean_assert(is_metavar(m));
lean_assert(is_app(rhs));
if (!u.m_pattern) {
expr const & f = get_app_fn(rhs);
lean_assert(is_constant(f) || is_local(f));
if (is_local(f)) {
unsigned i = margs.size();
if (!u.m_pattern) {
while (i > 0) {
--i;
if (!(is_local(margs[i]) && mlocal_name(margs[i]) == mlocal_name(f)))
mk_simple_nonlocal_projection(i);
}
}
} else {
mk_simple_projections();
}
}
}
/** \brief Create the local context \c locals for the imitiation step.
*/
void mk_local_context(buffer<expr> & locals, buffer<constraint> & cs) {
expr mtype = mlocal_type(m);
unsigned nargs = margs.size();
mtype = ensure_sufficient_args(mtype, cs);
expr it = mtype;
for (unsigned i = 0; i < nargs; i++) {
expr d = instantiate_rev(binding_domain(it), locals.size(), locals.data());
auto d_jst = u.m_subst.instantiate_metavars(d);
d = d_jst.first;
j = mk_composite1(j, d_jst.second);
name n;
if (is_local(margs[i]) && local_occurs_once(margs[i], margs)) {
n = mlocal_name(margs[i]);
} else {
n = u.m_ngen.next();
}
expr local = mk_local(n, binding_name(it), d, binding_info(it));
locals.push_back(local);
it = binding_body(it);
}
}
expr mk_imitiation_arg(expr const & arg, expr const & type, buffer<expr> const & locals,
buffer<constraint> & cs) {
if (!has_meta_args() && is_local(arg) && contains_local(arg, locals)) {
return arg;
} else {
// std::cout << "type: " << type << "\n";
if (context_check(type, locals)) {
expr maux = mk_metavar(u.m_ngen.next(), Pi(locals, type));
// std::cout << " >> " << maux << " : " << mlocal_type(maux) << "\n";
cs.push_back(mk_eq_cnstr(mk_app(maux, margs), arg, j, relax));
return mk_app(maux, locals);
} else {
expr maux_type = mk_metavar(u.m_ngen.next(), Pi(locals, mk_sort(mk_meta_univ(u.m_ngen.next()))));
expr maux = mk_metavar(u.m_ngen.next(), Pi(locals, mk_app(maux_type, locals)));
cs.push_back(mk_eq_cnstr(mk_app(maux_type, locals), type, j, relax));
cs.push_back(mk_eq_cnstr(mk_app(maux, margs), arg, j, relax));
return mk_app(maux, locals);
}
}
}
void mk_app_imitation_core(expr const & f, buffer<expr> const & locals, buffer<constraint> & cs) {
buffer<expr> rargs;
get_app_args(rhs, rargs);
buffer<expr> sargs;
try {
// create a scope to make sure no constraints "leak" into the current state
type_checker::scope scope(*u.m_tc[relax]);
expr f_type = u.m_tc[relax]->infer(f);
for (expr const & rarg : rargs) {
f_type = u.m_tc[relax]->ensure_pi(f_type);
expr d_type = binding_domain(f_type);
expr sarg = mk_imitiation_arg(rarg, d_type, locals, cs);
sargs.push_back(sarg);
f_type = instantiate(binding_body(f_type), sarg);
}
copy_pending_constraints(cs);
} catch (exception&) {}
expr v = Fun(locals, mk_app(f, sargs));
// std::cout << " >> app imitation, v: " << v << "\n";
lean_assert(!has_local(v));
cs.push_back(mk_eq_cnstr(m, v, j, relax));
alts.push_back(to_list(cs.begin(), cs.end()));
}
/**
\brief Given
m := a metavariable ?m
margs := [a_1 ... a_k]
rhs := (g b_1 ... b_n)
rhs := (f b_1 ... b_n)
Then create the constraints
(?m_1 a_1 ... a_k) =?= b_1
...
(?m_n a_1 ... a_k) =?= b_n
?m =?= fun (x_1 ... x_k), f (?m_1 x_1 ... x_k) ... (?m_n x_1 ... x_k)
?m =?= fun (x_1 ... x_k), g (?m_1 x_1 ... x_k) ... (?m_n x_1 ... x_k)
If f is a constant, then g is f.
If f is a local constant, then we consider each a_i that is equal to f.
Remark: we try to minimize the number of constraints (?m_i a_1 ... a_k) =?= b_i by detecting "easy" cases
that can be solved immediately. See \c is_easy_flex_rigid_arg
Remark: The term f is:
- g (if g is a constant), OR
- variable (if g is a local constant equal to a_i)
that can be solved immediately. See \c mk_imitiation_arg
*/
void mk_flex_rigid_app_cnstrs(expr const & f) {
void mk_app_imitation() {
lean_assert(is_metavar(m));
lean_assert(is_app(rhs));
lean_assert(is_constant(f) || is_var(f));
lean_assert(!u.m_tc[relax]->next_cnstr()); // make sure there are no pending constraints
buffer<expr> locals;
buffer<constraint> cs;
expr mtype = mlocal_type(m);
mtype = ensure_sufficient_args(mtype, cs);
buffer<expr> rargs;
get_app_args(rhs, rargs);
buffer<expr> sargs;
for (expr const & rarg : rargs) {
if (auto v = is_easy_flex_rigid_arg(rarg)) {
sargs.push_back(*v);
} else {
expr maux = mk_aux_metavar_for(u.m_ngen, mtype);
cs.push_back(mk_eq_cnstr(mk_app(maux, margs), rarg, j, relax));
sargs.push_back(mk_app_vars(maux, margs.size()));
flet<justification> let(j, j); // save j value
mk_local_context(locals, cs);
lean_assert(margs.size() == locals.size());
expr const & f = get_app_fn(rhs);
lean_assert(is_constant(f) || is_local(f));
if (is_local(f)) {
unsigned cs_sz = cs.size();
unsigned i = margs.size();
while (i > 0) {
--i;
if (is_local(margs[i]) && mlocal_name(margs[i]) == mlocal_name(f)) {
cs.shrink(cs_sz);
mk_app_imitation_core(locals[i], locals, cs);
}
}
}
expr v = mk_app(f, sargs);
v = mk_lambda_for(mtype, v);
if (check_imitation(v, cs)) {
cs.push_back(mk_eq_cnstr(m, v, j, relax));
alts.push_back(to_list(cs.begin(), cs.end()));
} else {
mk_app_imitation_core(f, locals, cs);
}
}
@ -1436,23 +1554,32 @@ struct unifier_fn {
void mk_bindings_imitation() {
lean_assert(is_metavar(m));
lean_assert(is_binding(rhs));
lean_assert(!u.m_tc[relax]->next_cnstr()); // make sure there are no pending constraints
buffer<constraint> cs;
expr mtype = mlocal_type(m);
mtype = ensure_sufficient_args(mtype, cs);
expr maux1 = mk_aux_metavar_for(u.m_ngen, mtype);
cs.push_back(mk_eq_cnstr(mk_app(maux1, margs), binding_domain(rhs), j, relax));
expr dontcare;
expr tmp_pi = mk_pi(binding_name(rhs), mk_app_vars(maux1, margs.size()), dontcare); // trick for "extending" the context
expr mtype2 = replace_range(mtype, tmp_pi); // trick for "extending" the context
expr maux2 = mk_aux_metavar_for(u.m_ngen, mtype2);
expr new_local = u.mk_local_for(rhs);
cs.push_back(mk_eq_cnstr(mk_app(mk_app(maux2, margs), new_local), instantiate(binding_body(rhs), new_local), j, relax));
expr v = update_binding(rhs, mk_app_vars(maux1, margs.size()), mk_app_vars(maux2, margs.size() + 1));
v = mk_lambda_for(mtype, v);
if (check_imitation(v, cs)) {
buffer<expr> locals;
flet<justification> let(j, j); // save j value
mk_local_context(locals, cs);
lean_assert(margs.size() == locals.size());
try {
// create a scope to make sure no constraints "leak" into the current state
type_checker::scope scope(*u.m_tc[relax]);
expr rhs_A = binding_domain(rhs);
expr A_type = u.m_tc[relax]->infer(rhs_A);
expr A = mk_imitiation_arg(rhs_A, A_type, locals, cs);
expr local = mk_local(u.m_ngen.next(), binding_name(rhs), A, binding_info(rhs));
locals.push_back(local);
margs.push_back(local);
expr rhs_B = instantiate(binding_body(rhs), local);
expr B_type = u.m_tc[relax]->infer(rhs_B);
expr B = mk_imitiation_arg(rhs_B, B_type, locals, cs);
expr binding = is_pi(rhs) ? Pi(local, B) : Fun(local, B);
locals.pop_back();
expr v = Fun(locals, binding);
copy_pending_constraints(cs);
cs.push_back(mk_eq_cnstr(m, v, j, relax));
alts.push_back(to_list(cs.begin(), cs.end()));
}
} catch (exception&) {}
margs.pop_back();
}
/**
@ -1514,27 +1641,11 @@ struct unifier_fn {
mk_simple_projections();
mk_macro_imitation();
break;
case expr_kind::App: {
expr const & f = get_app_fn(rhs);
if (is_local(f)) {
unsigned i = margs.size();
while (i > 0) {
unsigned vidx = margs.size() - i;
--i;
expr const & marg = margs[i];
if (is_local(marg) && mlocal_name(marg) == mlocal_name(f))
mk_flex_rigid_app_cnstrs(mk_var(vidx));
else if (!u.m_pattern)
mk_simple_nonlocal_projection(i);
}
} else {
lean_assert(is_constant(f));
if (!u.m_pattern)
mk_simple_projections();
mk_flex_rigid_app_cnstrs(f);
}
case expr_kind::App:
mk_app_projections();
mk_app_imitation();
break;
}}
}
}
};
@ -1600,6 +1711,14 @@ struct unifier_fn {
}
}
// std::cout << "FlexRigid\n";
// for (auto cs : alts) {
// std::cout << " alternative\n";
// for (auto c : cs) {
// std::cout << " >> " << c << "\n";
// }
// }
if (alts.empty()) {
set_conflict(j);
return false;
@ -1743,6 +1862,7 @@ struct unifier_fn {
if (!m_expensive && cidx >= get_group_first_index(cnstr_group::DelayedChoice2))
m_pattern = true; // use only higher-order (pattern) matching after we start processing MaxDelayed (aka class-instance constraints)
constraint c = p->first;
// std::cout << "process_next: " << c << "\n";
m_cnstrs.erase_min();
if (is_choice_cnstr(c)) {
return process_choice_constraint(c);