refactor(abstract_expr_manager): use get_specialization_prefix_size to improve performance of abstract_expr_manager

This commit is contained in:
Leonardo de Moura 2016-01-06 17:29:48 -08:00
parent d4a5aa6db0
commit 9fa1a7a01c
5 changed files with 56 additions and 44 deletions

View file

@ -32,23 +32,24 @@ unsigned abstract_expr_manager::hash(expr const & e) {
m_locals.pop_back(); m_locals.pop_back();
return h; return h;
case expr_kind::App: case expr_kind::App:
// TODO(Leo): in the past we only had to apply instantiate_rev to the function.
// We have to improve this.
// One idea is to compute and cache the specialization prefix for f.
// Then, we only need to apply instantiate_rev to f + prefix.
// expr f = instantiate_rev(get_app_args(e, args), m_locals.size(), m_locals.data());
expr new_e = instantiate_rev(e, m_locals.size(), m_locals.data());
optional<congr_lemma> f_congr = m_congr_lemma_manager.mk_specialized_congr(new_e);
buffer<expr> args; buffer<expr> args;
expr const & f = get_app_args(new_e, args); expr const & f = get_app_args(e, args);
h = hash(f); unsigned prefix_sz = m_congr_lemma_manager.get_specialization_prefix_size(f, args.size());
if (!f_congr) { expr new_f = e;
for (expr const & arg : args) { unsigned rest_sz = args.size() - prefix_sz;
h = ::lean::hash(h, hash(arg)); for (unsigned i = 0; i < rest_sz; i++)
new_f = app_fn(new_f);
new_f = instantiate_rev(new_f, m_locals.size(), m_locals.data());
optional<congr_lemma> congr = m_congr_lemma_manager.mk_congr(new_f, rest_sz);
h = hash(new_f);
if (!congr) {
for (unsigned i = prefix_sz; i < args.size(); i++) {
h = ::lean::hash(h, hash(args[i]));
} }
} else { } else {
unsigned i = 0; lean_assert(length(congr->get_arg_kinds()) == rest_sz);
for_each(f_congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) { unsigned i = prefix_sz;
for_each(congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
if (c_kind != congr_arg_kind::Cast) { if (c_kind != congr_arg_kind::Cast) {
h = ::lean::hash(h, hash(args[i])); h = ::lean::hash(h, hash(args[i]));
} }
@ -88,36 +89,40 @@ bool abstract_expr_manager::is_equal(expr const & a, expr const & b) {
} }
return true; return true;
case expr_kind::App: case expr_kind::App:
if (!is_equal(get_app_fn(a), get_app_fn(b))) {
return false;
}
if (get_app_num_args(a) != get_app_num_args(b)) {
return false;
}
// See comment in the hash function
expr new_a = instantiate_rev(a, m_locals.size(), m_locals.data());
expr new_b = instantiate_rev(b, m_locals.size(), m_locals.data());
optional<congr_lemma> congra = m_congr_lemma_manager.mk_specialized_congr(new_a);
optional<congr_lemma> congrb = m_congr_lemma_manager.mk_specialized_congr(new_b);
buffer<expr> a_args, b_args; buffer<expr> a_args, b_args;
get_app_args(new_a, a_args); expr const & f_a = get_app_args(a, a_args);
get_app_args(new_b, b_args); expr const & f_b = get_app_args(b, b_args);
if (!is_equal(f_a, f_b))
return false;
if (a_args.size() != b_args.size())
return false;
unsigned prefix_sz = m_congr_lemma_manager.get_specialization_prefix_size(f_a, a_args.size());
for (unsigned i = 0; i < prefix_sz; i++) {
if (!is_equal(a_args[i], b_args[i]))
return false;
}
expr new_f_a = a;
unsigned rest_sz = a_args.size() - prefix_sz;
for (unsigned i = 0; i < rest_sz; i++) {
new_f_a = app_fn(new_f_a);
}
new_f_a = instantiate_rev(new_f_a, m_locals.size(), m_locals.data());
optional<congr_lemma> congr = m_congr_lemma_manager.mk_congr(new_f_a, rest_sz);
bool not_equal = false; bool not_equal = false;
if (!congra || !congrb) { if (!congr) {
for (unsigned i = 0; i < a_args.size(); ++i) { for (unsigned i = prefix_sz; i < a_args.size(); ++i) {
if (!is_equal(a_args[i], b_args[i])) { if (!is_equal(a_args[i], b_args[i])) {
not_equal = true; not_equal = true;
break; break;
} }
} }
} else { } else {
unsigned i = 0; lean_assert(length(congr->get_arg_kinds()) == rest_sz);
for_each2(congra->get_arg_kinds(), unsigned i = prefix_sz;
congrb->get_arg_kinds(), for_each(congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
[&](congr_arg_kind const & ca_kind, congr_arg_kind const & cb_kind) {
if (not_equal) if (not_equal)
return; return;
if (ca_kind != cb_kind || (ca_kind != congr_arg_kind::Cast && !is_equal(a_args[i], b_args[i]))) { if (c_kind != congr_arg_kind::Cast && !is_equal(a_args[i], b_args[i])) {
not_equal = true; not_equal = true;
} }
i++; i++;

View file

@ -586,6 +586,10 @@ public:
optional<result> mk_rel_eq_congr(expr const & R) { optional<result> mk_rel_eq_congr(expr const & R) {
return mk_rel_congr(R, false); return mk_rel_congr(R, false);
} }
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs) {
return m_fmanager.get_specialization_prefix_size(fn, nargs);
}
}; };
congr_lemma_manager::congr_lemma_manager(app_builder & b, fun_info_manager & fm): congr_lemma_manager::congr_lemma_manager(app_builder & b, fun_info_manager & fm):
@ -618,10 +622,12 @@ auto congr_lemma_manager::mk_specialized_congr(expr const & fn) -> optional<resu
auto congr_lemma_manager::mk_rel_iff_congr(expr const & R) -> optional<result> { auto congr_lemma_manager::mk_rel_iff_congr(expr const & R) -> optional<result> {
return m_ptr->mk_rel_iff_congr(R); return m_ptr->mk_rel_iff_congr(R);
} }
auto congr_lemma_manager::mk_rel_eq_congr(expr const & R) -> optional<result> { auto congr_lemma_manager::mk_rel_eq_congr(expr const & R) -> optional<result> {
return m_ptr->mk_rel_eq_congr(R); return m_ptr->mk_rel_eq_congr(R);
} }
unsigned congr_lemma_manager::get_specialization_prefix_size(expr const & fn, unsigned nargs) {
return m_ptr->get_specialization_prefix_size(fn, nargs);
}
void initialize_congr_lemma_manager() { void initialize_congr_lemma_manager() {
register_trace_class("congruence_manager"); register_trace_class("congruence_manager");

View file

@ -46,6 +46,7 @@ public:
typedef congr_lemma result; typedef congr_lemma result;
type_context & ctx(); type_context & ctx();
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs);
optional<result> mk_congr_simp(expr const & fn); optional<result> mk_congr_simp(expr const & fn);
optional<result> mk_congr_simp(expr const & fn, unsigned nargs); optional<result> mk_congr_simp(expr const & fn, unsigned nargs);

View file

@ -164,7 +164,7 @@ void fun_info_manager::trace_if_unsupported(expr const & fn, buffer<expr> const
} }
} }
unsigned fun_info_manager::get_prefix(expr const & fn, unsigned nargs) { unsigned fun_info_manager::get_specialization_prefix_size(expr const & fn, unsigned nargs) {
/* /*
We say a function is "cheap" if it is of the form: We say a function is "cheap" if it is of the form:
@ -222,7 +222,7 @@ fun_info fun_info_manager::get_specialized(expr const & a) {
lean_assert(is_app(a)); lean_assert(is_app(a));
buffer<expr> args; buffer<expr> args;
expr const & fn = get_app_args(a, args); expr const & fn = get_app_args(a, args);
unsigned prefix_sz = get_prefix(fn, args.size()); unsigned prefix_sz = get_specialization_prefix_size(fn, args.size());
unsigned num_rest_args = args.size() - prefix_sz; unsigned num_rest_args = args.size() - prefix_sz;
expr g = a; expr g = a;
for (unsigned i = 0; i < num_rest_args; i++) for (unsigned i = 0; i < num_rest_args; i++)

View file

@ -122,7 +122,7 @@ public:
c) (inv : Pi {A : Type} [s : has_inv A] (x : A) (h : invertible x), A) c) (inv : Pi {A : Type} [s : has_inv A] (x : A) (h : invertible x), A)
result 2 result 2
*/ */
unsigned get_prefix(expr const & fn, unsigned nargs); unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs);
}; };
void initialize_fun_info_manager(); void initialize_fun_info_manager();