refactor(abstract_expr_manager): use get_specialization_prefix_size to improve performance of abstract_expr_manager
This commit is contained in:
parent
d4a5aa6db0
commit
9fa1a7a01c
5 changed files with 56 additions and 44 deletions
|
@ -32,23 +32,24 @@ unsigned abstract_expr_manager::hash(expr const & e) {
|
|||
m_locals.pop_back();
|
||||
return h;
|
||||
case expr_kind::App:
|
||||
// TODO(Leo): in the past we only had to apply instantiate_rev to the function.
|
||||
// We have to improve this.
|
||||
// One idea is to compute and cache the specialization prefix for f.
|
||||
// Then, we only need to apply instantiate_rev to f + prefix.
|
||||
// expr f = instantiate_rev(get_app_args(e, args), m_locals.size(), m_locals.data());
|
||||
expr new_e = instantiate_rev(e, m_locals.size(), m_locals.data());
|
||||
optional<congr_lemma> f_congr = m_congr_lemma_manager.mk_specialized_congr(new_e);
|
||||
buffer<expr> args;
|
||||
expr const & f = get_app_args(new_e, args);
|
||||
h = hash(f);
|
||||
if (!f_congr) {
|
||||
for (expr const & arg : args) {
|
||||
h = ::lean::hash(h, hash(arg));
|
||||
expr const & f = get_app_args(e, args);
|
||||
unsigned prefix_sz = m_congr_lemma_manager.get_specialization_prefix_size(f, args.size());
|
||||
expr new_f = e;
|
||||
unsigned rest_sz = args.size() - prefix_sz;
|
||||
for (unsigned i = 0; i < rest_sz; i++)
|
||||
new_f = app_fn(new_f);
|
||||
new_f = instantiate_rev(new_f, m_locals.size(), m_locals.data());
|
||||
optional<congr_lemma> congr = m_congr_lemma_manager.mk_congr(new_f, rest_sz);
|
||||
h = hash(new_f);
|
||||
if (!congr) {
|
||||
for (unsigned i = prefix_sz; i < args.size(); i++) {
|
||||
h = ::lean::hash(h, hash(args[i]));
|
||||
}
|
||||
} else {
|
||||
unsigned i = 0;
|
||||
for_each(f_congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
|
||||
lean_assert(length(congr->get_arg_kinds()) == rest_sz);
|
||||
unsigned i = prefix_sz;
|
||||
for_each(congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
|
||||
if (c_kind != congr_arg_kind::Cast) {
|
||||
h = ::lean::hash(h, hash(args[i]));
|
||||
}
|
||||
|
@ -88,40 +89,44 @@ bool abstract_expr_manager::is_equal(expr const & a, expr const & b) {
|
|||
}
|
||||
return true;
|
||||
case expr_kind::App:
|
||||
if (!is_equal(get_app_fn(a), get_app_fn(b))) {
|
||||
return false;
|
||||
}
|
||||
if (get_app_num_args(a) != get_app_num_args(b)) {
|
||||
return false;
|
||||
}
|
||||
// See comment in the hash function
|
||||
expr new_a = instantiate_rev(a, m_locals.size(), m_locals.data());
|
||||
expr new_b = instantiate_rev(b, m_locals.size(), m_locals.data());
|
||||
optional<congr_lemma> congra = m_congr_lemma_manager.mk_specialized_congr(new_a);
|
||||
optional<congr_lemma> congrb = m_congr_lemma_manager.mk_specialized_congr(new_b);
|
||||
buffer<expr> a_args, b_args;
|
||||
get_app_args(new_a, a_args);
|
||||
get_app_args(new_b, b_args);
|
||||
expr const & f_a = get_app_args(a, a_args);
|
||||
expr const & f_b = get_app_args(b, b_args);
|
||||
if (!is_equal(f_a, f_b))
|
||||
return false;
|
||||
if (a_args.size() != b_args.size())
|
||||
return false;
|
||||
unsigned prefix_sz = m_congr_lemma_manager.get_specialization_prefix_size(f_a, a_args.size());
|
||||
for (unsigned i = 0; i < prefix_sz; i++) {
|
||||
if (!is_equal(a_args[i], b_args[i]))
|
||||
return false;
|
||||
}
|
||||
expr new_f_a = a;
|
||||
unsigned rest_sz = a_args.size() - prefix_sz;
|
||||
for (unsigned i = 0; i < rest_sz; i++) {
|
||||
new_f_a = app_fn(new_f_a);
|
||||
}
|
||||
new_f_a = instantiate_rev(new_f_a, m_locals.size(), m_locals.data());
|
||||
optional<congr_lemma> congr = m_congr_lemma_manager.mk_congr(new_f_a, rest_sz);
|
||||
bool not_equal = false;
|
||||
if (!congra || !congrb) {
|
||||
for (unsigned i = 0; i < a_args.size(); ++i) {
|
||||
if (!congr) {
|
||||
for (unsigned i = prefix_sz; i < a_args.size(); ++i) {
|
||||
if (!is_equal(a_args[i], b_args[i])) {
|
||||
not_equal = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
unsigned i = 0;
|
||||
for_each2(congra->get_arg_kinds(),
|
||||
congrb->get_arg_kinds(),
|
||||
[&](congr_arg_kind const & ca_kind, congr_arg_kind const & cb_kind) {
|
||||
if (not_equal)
|
||||
return;
|
||||
if (ca_kind != cb_kind || (ca_kind != congr_arg_kind::Cast && !is_equal(a_args[i], b_args[i]))) {
|
||||
not_equal = true;
|
||||
}
|
||||
i++;
|
||||
});
|
||||
lean_assert(length(congr->get_arg_kinds()) == rest_sz);
|
||||
unsigned i = prefix_sz;
|
||||
for_each(congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
|
||||
if (not_equal)
|
||||
return;
|
||||
if (c_kind != congr_arg_kind::Cast && !is_equal(a_args[i], b_args[i])) {
|
||||
not_equal = true;
|
||||
}
|
||||
i++;
|
||||
});
|
||||
}
|
||||
return !not_equal;
|
||||
}
|
||||
|
|
|
@ -586,6 +586,10 @@ public:
|
|||
optional<result> mk_rel_eq_congr(expr const & R) {
|
||||
return mk_rel_congr(R, false);
|
||||
}
|
||||
|
||||
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs) {
|
||||
return m_fmanager.get_specialization_prefix_size(fn, nargs);
|
||||
}
|
||||
};
|
||||
|
||||
congr_lemma_manager::congr_lemma_manager(app_builder & b, fun_info_manager & fm):
|
||||
|
@ -618,10 +622,12 @@ auto congr_lemma_manager::mk_specialized_congr(expr const & fn) -> optional<resu
|
|||
auto congr_lemma_manager::mk_rel_iff_congr(expr const & R) -> optional<result> {
|
||||
return m_ptr->mk_rel_iff_congr(R);
|
||||
}
|
||||
|
||||
auto congr_lemma_manager::mk_rel_eq_congr(expr const & R) -> optional<result> {
|
||||
return m_ptr->mk_rel_eq_congr(R);
|
||||
}
|
||||
unsigned congr_lemma_manager::get_specialization_prefix_size(expr const & fn, unsigned nargs) {
|
||||
return m_ptr->get_specialization_prefix_size(fn, nargs);
|
||||
}
|
||||
|
||||
void initialize_congr_lemma_manager() {
|
||||
register_trace_class("congruence_manager");
|
||||
|
|
|
@ -46,6 +46,7 @@ public:
|
|||
typedef congr_lemma result;
|
||||
|
||||
type_context & ctx();
|
||||
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs);
|
||||
|
||||
optional<result> mk_congr_simp(expr const & fn);
|
||||
optional<result> mk_congr_simp(expr const & fn, unsigned nargs);
|
||||
|
|
|
@ -164,7 +164,7 @@ void fun_info_manager::trace_if_unsupported(expr const & fn, buffer<expr> const
|
|||
}
|
||||
}
|
||||
|
||||
unsigned fun_info_manager::get_prefix(expr const & fn, unsigned nargs) {
|
||||
unsigned fun_info_manager::get_specialization_prefix_size(expr const & fn, unsigned nargs) {
|
||||
/*
|
||||
We say a function is "cheap" if it is of the form:
|
||||
|
||||
|
@ -222,7 +222,7 @@ fun_info fun_info_manager::get_specialized(expr const & a) {
|
|||
lean_assert(is_app(a));
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(a, args);
|
||||
unsigned prefix_sz = get_prefix(fn, args.size());
|
||||
unsigned prefix_sz = get_specialization_prefix_size(fn, args.size());
|
||||
unsigned num_rest_args = args.size() - prefix_sz;
|
||||
expr g = a;
|
||||
for (unsigned i = 0; i < num_rest_args; i++)
|
||||
|
|
|
@ -122,7 +122,7 @@ public:
|
|||
c) (inv : Pi {A : Type} [s : has_inv A] (x : A) (h : invertible x), A)
|
||||
result 2
|
||||
*/
|
||||
unsigned get_prefix(expr const & fn, unsigned nargs);
|
||||
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs);
|
||||
};
|
||||
|
||||
void initialize_fun_info_manager();
|
||||
|
|
Loading…
Reference in a new issue