refactor(abstract_expr_manager): use get_specialization_prefix_size to improve performance of abstract_expr_manager

This commit is contained in:
Leonardo de Moura 2016-01-06 17:29:48 -08:00
parent d4a5aa6db0
commit 9fa1a7a01c
5 changed files with 56 additions and 44 deletions

View file

@ -32,23 +32,24 @@ unsigned abstract_expr_manager::hash(expr const & e) {
m_locals.pop_back();
return h;
case expr_kind::App:
// TODO(Leo): in the past we only had to apply instantiate_rev to the function.
// We have to improve this.
// One idea is to compute and cache the specialization prefix for f.
// Then, we only need to apply instantiate_rev to f + prefix.
// expr f = instantiate_rev(get_app_args(e, args), m_locals.size(), m_locals.data());
expr new_e = instantiate_rev(e, m_locals.size(), m_locals.data());
optional<congr_lemma> f_congr = m_congr_lemma_manager.mk_specialized_congr(new_e);
buffer<expr> args;
expr const & f = get_app_args(new_e, args);
h = hash(f);
if (!f_congr) {
for (expr const & arg : args) {
h = ::lean::hash(h, hash(arg));
expr const & f = get_app_args(e, args);
unsigned prefix_sz = m_congr_lemma_manager.get_specialization_prefix_size(f, args.size());
expr new_f = e;
unsigned rest_sz = args.size() - prefix_sz;
for (unsigned i = 0; i < rest_sz; i++)
new_f = app_fn(new_f);
new_f = instantiate_rev(new_f, m_locals.size(), m_locals.data());
optional<congr_lemma> congr = m_congr_lemma_manager.mk_congr(new_f, rest_sz);
h = hash(new_f);
if (!congr) {
for (unsigned i = prefix_sz; i < args.size(); i++) {
h = ::lean::hash(h, hash(args[i]));
}
} else {
unsigned i = 0;
for_each(f_congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
lean_assert(length(congr->get_arg_kinds()) == rest_sz);
unsigned i = prefix_sz;
for_each(congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
if (c_kind != congr_arg_kind::Cast) {
h = ::lean::hash(h, hash(args[i]));
}
@ -88,40 +89,44 @@ bool abstract_expr_manager::is_equal(expr const & a, expr const & b) {
}
return true;
case expr_kind::App:
if (!is_equal(get_app_fn(a), get_app_fn(b))) {
return false;
}
if (get_app_num_args(a) != get_app_num_args(b)) {
return false;
}
// See comment in the hash function
expr new_a = instantiate_rev(a, m_locals.size(), m_locals.data());
expr new_b = instantiate_rev(b, m_locals.size(), m_locals.data());
optional<congr_lemma> congra = m_congr_lemma_manager.mk_specialized_congr(new_a);
optional<congr_lemma> congrb = m_congr_lemma_manager.mk_specialized_congr(new_b);
buffer<expr> a_args, b_args;
get_app_args(new_a, a_args);
get_app_args(new_b, b_args);
expr const & f_a = get_app_args(a, a_args);
expr const & f_b = get_app_args(b, b_args);
if (!is_equal(f_a, f_b))
return false;
if (a_args.size() != b_args.size())
return false;
unsigned prefix_sz = m_congr_lemma_manager.get_specialization_prefix_size(f_a, a_args.size());
for (unsigned i = 0; i < prefix_sz; i++) {
if (!is_equal(a_args[i], b_args[i]))
return false;
}
expr new_f_a = a;
unsigned rest_sz = a_args.size() - prefix_sz;
for (unsigned i = 0; i < rest_sz; i++) {
new_f_a = app_fn(new_f_a);
}
new_f_a = instantiate_rev(new_f_a, m_locals.size(), m_locals.data());
optional<congr_lemma> congr = m_congr_lemma_manager.mk_congr(new_f_a, rest_sz);
bool not_equal = false;
if (!congra || !congrb) {
for (unsigned i = 0; i < a_args.size(); ++i) {
if (!congr) {
for (unsigned i = prefix_sz; i < a_args.size(); ++i) {
if (!is_equal(a_args[i], b_args[i])) {
not_equal = true;
break;
}
}
} else {
unsigned i = 0;
for_each2(congra->get_arg_kinds(),
congrb->get_arg_kinds(),
[&](congr_arg_kind const & ca_kind, congr_arg_kind const & cb_kind) {
if (not_equal)
return;
if (ca_kind != cb_kind || (ca_kind != congr_arg_kind::Cast && !is_equal(a_args[i], b_args[i]))) {
not_equal = true;
}
i++;
});
lean_assert(length(congr->get_arg_kinds()) == rest_sz);
unsigned i = prefix_sz;
for_each(congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
if (not_equal)
return;
if (c_kind != congr_arg_kind::Cast && !is_equal(a_args[i], b_args[i])) {
not_equal = true;
}
i++;
});
}
return !not_equal;
}

View file

@ -586,6 +586,10 @@ public:
optional<result> mk_rel_eq_congr(expr const & R) {
return mk_rel_congr(R, false);
}
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs) {
return m_fmanager.get_specialization_prefix_size(fn, nargs);
}
};
congr_lemma_manager::congr_lemma_manager(app_builder & b, fun_info_manager & fm):
@ -618,10 +622,12 @@ auto congr_lemma_manager::mk_specialized_congr(expr const & fn) -> optional<resu
auto congr_lemma_manager::mk_rel_iff_congr(expr const & R) -> optional<result> {
return m_ptr->mk_rel_iff_congr(R);
}
auto congr_lemma_manager::mk_rel_eq_congr(expr const & R) -> optional<result> {
return m_ptr->mk_rel_eq_congr(R);
}
unsigned congr_lemma_manager::get_specialization_prefix_size(expr const & fn, unsigned nargs) {
return m_ptr->get_specialization_prefix_size(fn, nargs);
}
void initialize_congr_lemma_manager() {
register_trace_class("congruence_manager");

View file

@ -46,6 +46,7 @@ public:
typedef congr_lemma result;
type_context & ctx();
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs);
optional<result> mk_congr_simp(expr const & fn);
optional<result> mk_congr_simp(expr const & fn, unsigned nargs);

View file

@ -164,7 +164,7 @@ void fun_info_manager::trace_if_unsupported(expr const & fn, buffer<expr> const
}
}
unsigned fun_info_manager::get_prefix(expr const & fn, unsigned nargs) {
unsigned fun_info_manager::get_specialization_prefix_size(expr const & fn, unsigned nargs) {
/*
We say a function is "cheap" if it is of the form:
@ -222,7 +222,7 @@ fun_info fun_info_manager::get_specialized(expr const & a) {
lean_assert(is_app(a));
buffer<expr> args;
expr const & fn = get_app_args(a, args);
unsigned prefix_sz = get_prefix(fn, args.size());
unsigned prefix_sz = get_specialization_prefix_size(fn, args.size());
unsigned num_rest_args = args.size() - prefix_sz;
expr g = a;
for (unsigned i = 0; i < num_rest_args; i++)

View file

@ -122,7 +122,7 @@ public:
c) (inv : Pi {A : Type} [s : has_inv A] (x : A) (h : invertible x), A)
result 2
*/
unsigned get_prefix(expr const & fn, unsigned nargs);
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs);
};
void initialize_fun_info_manager();