feat(library/fun_info_manager,library/congr_lemma_manager,blast/simplifier): specialized congruence lemmas

We still need a lot of polishing.
This commit is contained in:
Leonardo de Moura 2016-01-06 17:29:35 -08:00
parent 930fcddace
commit f3b8aef24c
10 changed files with 353 additions and 149 deletions

View file

@ -32,18 +32,27 @@ unsigned abstract_expr_manager::hash(expr const & e) {
m_locals.pop_back(); m_locals.pop_back();
return h; return h;
case expr_kind::App: case expr_kind::App:
// TODO(Leo): in the past we only had to apply instantiate_rev to the function.
// We have to improve this.
// One idea is to compute and cache the specialization prefix for f.
// Then, we only need to apply instantiate_rev to f + prefix.
// expr f = instantiate_rev(get_app_args(e, args), m_locals.size(), m_locals.data());
expr new_e = instantiate_rev(e, m_locals.size(), m_locals.data());
optional<congr_lemma> f_congr = m_congr_lemma_manager.mk_specialized_congr(new_e);
buffer<expr> args; buffer<expr> args;
expr f = instantiate_rev(get_app_args(e, args), m_locals.size(), m_locals.data()); expr const & f = get_app_args(new_e, args);
optional<congr_lemma> f_congr = m_congr_lemma_manager.mk_congr(f, args.size());
h = hash(f); h = hash(f);
if (!f_congr) { if (!f_congr) {
for (expr const & arg : args) h = ::lean::hash(h, hash(arg)); for (expr const & arg : args) {
h = ::lean::hash(h, hash(arg));
}
} else { } else {
int i = -1; unsigned i = 0;
for_each(f_congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) { for_each(f_congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
if (c_kind != congr_arg_kind::Cast) {
h = ::lean::hash(h, hash(args[i]));
}
i++; i++;
if (c_kind == congr_arg_kind::Cast) return;
h = ::lean::hash(h, hash(args[i]));
}); });
} }
return h; return h;
@ -52,61 +61,70 @@ unsigned abstract_expr_manager::hash(expr const & e) {
} }
bool abstract_expr_manager::is_equal(expr const & a, expr const & b) { bool abstract_expr_manager::is_equal(expr const & a, expr const & b) {
if (is_eqp(a, b)) return true; if (is_eqp(a, b)) return true;
if (hash(a) != hash(b)) return false; if (a.kind() != b.kind()) return false;
if (a.kind() != b.kind()) return false; if (is_var(a)) return var_idx(a) == var_idx(b);
if (is_var(a)) return var_idx(a) == var_idx(b); bool is_eq;
bool is_eq; switch (a.kind()) {
switch (a.kind()) { case expr_kind::Var:
case expr_kind::Var:
lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Constant: case expr_kind::Sort:
return a == b;
case expr_kind::Meta: case expr_kind::Local:
return mlocal_name(a) == mlocal_name(b) && is_equal(mlocal_type(a), mlocal_type(b));
case expr_kind::Lambda: case expr_kind::Pi:
if (!is_equal(binding_domain(a), binding_domain(b))) return false;
// see comment at abstract_expr_manager::hash
m_locals.push_back(instantiate_rev(m_tctx.mk_tmp_local(binding_domain(a)), m_locals.size(), m_locals.data()));
is_eq = is_equal(binding_body(a), binding_body(b));
m_locals.pop_back();
return is_eq;
case expr_kind::Macro:
if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b))
return false;
for (unsigned i = 0; i < macro_num_args(a); i++) {
if (!is_equal(macro_arg(a, i), macro_arg(b, i)))
return false;
}
return true;
case expr_kind::App:
buffer<expr> a_args, b_args;
expr f_a = get_app_args(a, a_args);
expr f_b = get_app_args(b, b_args);
if (!is_equal(f_a, f_b)) return false;
if (a_args.size() != b_args.size()) return false;
expr f = instantiate_rev(f_a, m_locals.size(), m_locals.data());
optional<congr_lemma> f_congr = m_congr_lemma_manager.mk_congr(f, a_args.size());
bool not_equal = false;
if (!f_congr) {
for (unsigned i = 0; i < a_args.size(); ++i) {
if (!is_equal(a_args[i], b_args[i])) {
not_equal = true;
break;
}
}
} else {
int i = -1;
for_each(f_congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
if (not_equal) return;
i++;
if (c_kind == congr_arg_kind::Cast) return;
if (!is_equal(a_args[i], b_args[i])) not_equal = true;
});
}
return !not_equal;
}
lean_unreachable(); // LCOV_EXCL_LINE lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Constant: case expr_kind::Sort:
return a == b;
case expr_kind::Meta: case expr_kind::Local:
return mlocal_name(a) == mlocal_name(b) && is_equal(mlocal_type(a), mlocal_type(b));
case expr_kind::Lambda: case expr_kind::Pi:
if (!is_equal(binding_domain(a), binding_domain(b))) return false;
// see comment at abstract_expr_manager::hash
m_locals.push_back(instantiate_rev(m_tctx.mk_tmp_local(binding_domain(a)), m_locals.size(), m_locals.data()));
is_eq = is_equal(binding_body(a), binding_body(b));
m_locals.pop_back();
return is_eq;
case expr_kind::Macro:
if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b))
return false;
for (unsigned i = 0; i < macro_num_args(a); i++) {
if (!is_equal(macro_arg(a, i), macro_arg(b, i)))
return false;
}
return true;
case expr_kind::App:
if (!is_equal(get_app_fn(a), get_app_fn(b))) {
return false;
}
if (get_app_num_args(a) != get_app_num_args(b)) {
return false;
}
// See comment in the hash function
expr new_a = instantiate_rev(a, m_locals.size(), m_locals.data());
expr new_b = instantiate_rev(b, m_locals.size(), m_locals.data());
optional<congr_lemma> congra = m_congr_lemma_manager.mk_specialized_congr(new_a);
optional<congr_lemma> congrb = m_congr_lemma_manager.mk_specialized_congr(new_b);
buffer<expr> a_args, b_args;
get_app_args(new_a, a_args);
get_app_args(new_b, b_args);
bool not_equal = false;
if (!congra || !congrb) {
for (unsigned i = 0; i < a_args.size(); ++i) {
if (!is_equal(a_args[i], b_args[i])) {
not_equal = true;
break;
}
}
} else {
unsigned i = 0;
for_each2(congra->get_arg_kinds(),
congrb->get_arg_kinds(),
[&](congr_arg_kind const & ca_kind, congr_arg_kind const & cb_kind) {
if (not_equal)
return;
if (ca_kind != cb_kind || (ca_kind != congr_arg_kind::Cast && !is_equal(a_args[i], b_args[i]))) {
not_equal = true;
}
i++;
});
}
return !not_equal;
}
lean_unreachable(); // LCOV_EXCL_LINE
} }
} }

View file

@ -719,6 +719,10 @@ public:
return m_congr_lemma_manager.mk_congr_simp(fn); return m_congr_lemma_manager.mk_congr_simp(fn);
} }
optional<congr_lemma> mk_specialized_congr_lemma_for_simp(expr const & fn) {
return m_congr_lemma_manager.mk_specialized_congr_simp(fn);
}
optional<congr_lemma> mk_congr_lemma(expr const & fn, unsigned num_args) { optional<congr_lemma> mk_congr_lemma(expr const & fn, unsigned num_args) {
return m_congr_lemma_manager.mk_congr(fn, num_args); return m_congr_lemma_manager.mk_congr(fn, num_args);
} }
@ -727,6 +731,10 @@ public:
return m_congr_lemma_manager.mk_congr(fn); return m_congr_lemma_manager.mk_congr(fn);
} }
optional<congr_lemma> mk_specialized_congr_lemma(expr const & a) {
return m_congr_lemma_manager.mk_specialized_congr(a);
}
optional<congr_lemma> mk_rel_iff_congr(expr const & fn) { optional<congr_lemma> mk_rel_iff_congr(expr const & fn) {
return m_congr_lemma_manager.mk_rel_iff_congr(fn); return m_congr_lemma_manager.mk_rel_iff_congr(fn);
} }
@ -743,6 +751,10 @@ public:
return m_fun_info_manager.get(fn, nargs); return m_fun_info_manager.get(fn, nargs);
} }
fun_info get_specialized_fun_info(expr const & a) {
return m_fun_info_manager.get_specialized(a);
}
unsigned abstract_hash(expr const & e) { unsigned abstract_hash(expr const & e) {
return m_abstract_expr_manager.hash(e); return m_abstract_expr_manager.hash(e);
} }
@ -1083,6 +1095,11 @@ optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn) {
return g_blastenv->mk_congr_lemma_for_simp(fn); return g_blastenv->mk_congr_lemma_for_simp(fn);
} }
optional<congr_lemma> mk_specialized_congr_lemma_for_simp(expr const & a) {
lean_assert(g_blastenv);
return g_blastenv->mk_specialized_congr_lemma_for_simp(a);
}
optional<congr_lemma> mk_congr_lemma(expr const & fn, unsigned num_args) { optional<congr_lemma> mk_congr_lemma(expr const & fn, unsigned num_args) {
lean_assert(g_blastenv); lean_assert(g_blastenv);
return g_blastenv->mk_congr_lemma(fn, num_args); return g_blastenv->mk_congr_lemma(fn, num_args);
@ -1093,6 +1110,11 @@ optional<congr_lemma> mk_congr_lemma(expr const & fn) {
return g_blastenv->mk_congr_lemma(fn); return g_blastenv->mk_congr_lemma(fn);
} }
optional<congr_lemma> mk_specialized_congr_lemma(expr const & a) {
lean_assert(g_blastenv);
return g_blastenv->mk_specialized_congr_lemma(a);
}
optional<congr_lemma> mk_rel_iff_congr(expr const & fn) { optional<congr_lemma> mk_rel_iff_congr(expr const & fn) {
lean_assert(g_blastenv); lean_assert(g_blastenv);
return g_blastenv->mk_rel_iff_congr(fn); return g_blastenv->mk_rel_iff_congr(fn);
@ -1113,6 +1135,11 @@ fun_info get_fun_info(expr const & fn, unsigned nargs) {
return g_blastenv->get_fun_info(fn, nargs); return g_blastenv->get_fun_info(fn, nargs);
} }
fun_info get_specialized_fun_info(expr const & a) {
lean_assert(g_blastenv);
return g_blastenv->get_specialized_fun_info(a);
}
unsigned abstract_hash(expr const & e) { unsigned abstract_hash(expr const & e) {
lean_assert(g_blastenv); lean_assert(g_blastenv);
return g_blastenv->abstract_hash(e); return g_blastenv->abstract_hash(e);
@ -1326,11 +1353,13 @@ void initialize_blast() {
register_trace_class_alias("app_builder", name({"blast", "event"})); register_trace_class_alias("app_builder", name({"blast", "event"}));
register_trace_class_alias(name({"simplifier", "failure"}), name({"blast", "event"})); register_trace_class_alias(name({"simplifier", "failure"}), name({"blast", "event"}));
register_trace_class_alias("fun_info", name({"blast", "event"}));
register_trace_class_alias(name({"cc", "propagation"}), "blast"); register_trace_class_alias(name({"cc", "propagation"}), "blast");
register_trace_class_alias("blast", "blast_detailed"); register_trace_class_alias("blast", "blast_detailed");
register_trace_class_alias("app_builder", "blast_detailed"); register_trace_class_alias("app_builder", "blast_detailed");
register_trace_class_alias("fun_info", "blast_detailed");
register_trace_class_alias(name({"simplifier", "failure"}), "blast_detailed"); register_trace_class_alias(name({"simplifier", "failure"}), "blast_detailed");
register_trace_class_alias(name({"cc", "merge"}), "blast_detailed"); register_trace_class_alias(name({"cc", "merge"}), "blast_detailed");

View file

@ -126,6 +126,10 @@ optional<expr> mk_class_instance(expr const & e);
optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn, unsigned num_args); optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn, unsigned num_args);
/** \brief Similar to previous procedure, but num_args == arith of fn */ /** \brief Similar to previous procedure, but num_args == arith of fn */
optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn); optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn);
/** \brief Similar to previous procedures, but \c a is a function application,
and the arguments are taken into account when computing the lemma.
\pre is_app(a) */
optional<congr_lemma> mk_specialized_congr_lemma_for_simp(expr const & a);
/** \brief Create a congruence lemma for the given function. /** \brief Create a congruence lemma for the given function.
\pre num_args <= arity of fn \pre num_args <= arity of fn
@ -140,6 +144,10 @@ optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn);
optional<congr_lemma> mk_congr_lemma(expr const & fn, unsigned num_args); optional<congr_lemma> mk_congr_lemma(expr const & fn, unsigned num_args);
/** \brief Similar to previous procedure, but num_args == arith of fn */ /** \brief Similar to previous procedure, but num_args == arith of fn */
optional<congr_lemma> mk_congr_lemma(expr const & fn); optional<congr_lemma> mk_congr_lemma(expr const & fn);
/** \brief Similar to previous procedures, but \c a is a function application,
and the arguments are taken into account when computing the lemma.
\pre is_app(a) */
optional<congr_lemma> mk_specialized_congr_lemma(expr const & a);
optional<congr_lemma> mk_rel_iff_congr(expr const & fn); optional<congr_lemma> mk_rel_iff_congr(expr const & fn);
optional<congr_lemma> mk_rel_eq_congr(expr const & fn); optional<congr_lemma> mk_rel_eq_congr(expr const & fn);
@ -148,6 +156,10 @@ fun_info get_fun_info(expr const & fn);
/** \brief Retrieve information for the given function. /** \brief Retrieve information for the given function.
\pre nargs <= arity fn. */ \pre nargs <= arity fn. */
fun_info get_fun_info(expr const & fn, unsigned nargs); fun_info get_fun_info(expr const & fn, unsigned nargs);
/** \brief Retrieve information for the given function-application
taking into account the actual arguments.
\pre is_app(a) */
fun_info get_specialized_fun_info(expr const & a);
/** \brief Hash and equality test for abstract expressions */ /** \brief Hash and equality test for abstract expressions */
unsigned abstract_hash(expr const & e); unsigned abstract_hash(expr const & e);

View file

@ -267,17 +267,17 @@ optional<result> simplifier::cache_lookup(expr const & e) {
if (e == e_old) return optional<result>(r_old); if (e == e_old) return optional<result>(r_old);
lean_assert(is_app(e_old)); lean_assert(is_app(e_old));
buffer<expr> new_args, old_args; buffer<expr> new_args, old_args;
expr const & f_new = get_app_args(e, new_args); DEBUG_CODE(expr const & f_new =) get_app_args(e, new_args);
DEBUG_CODE(expr const & f_old =) get_app_args(e_old, old_args); DEBUG_CODE(expr const & f_old =) get_app_args(e_old, old_args);
lean_assert(f_new == f_old); lean_assert(f_new == f_old);
auto congr_lemma = mk_congr_lemma(f_new, new_args.size()); auto congr_lemma = mk_specialized_congr_lemma(e);
if (!congr_lemma) return optional<result>(); if (!congr_lemma) return optional<result>();
expr proof = congr_lemma->get_proof(); expr proof = congr_lemma->get_proof();
expr type = congr_lemma->get_type(); expr type = congr_lemma->get_type();
unsigned i = 0; unsigned i = 0;
for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) { for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) {
lean_assert(ckind == congr_arg_kind::Cast || new_args[i] == old_args[i]); lean_assert(ckind == congr_arg_kind::Cast || new_args[i] == old_args[i], static_cast<unsigned>(ckind), new_args[i], old_args[i]);
expr rfl; expr rfl;
switch (ckind) { switch (ckind) {
case congr_arg_kind::Fixed: case congr_arg_kind::Fixed:
@ -469,7 +469,7 @@ result simplifier::simplify_pi(expr const & e) {
expr simplifier::unfold_reducible_instances(expr const & e) { expr simplifier::unfold_reducible_instances(expr const & e) {
buffer<expr> args; buffer<expr> args;
expr f = get_app_args(e, args); expr f = get_app_args(e, args);
fun_info f_info = get_fun_info(f, args.size()); fun_info f_info = get_specialized_fun_info(e);
int i = -1; int i = -1;
for_each(f_info.get_params_info(), [&](param_info const & p_info) { for_each(f_info.get_params_info(), [&](param_info const & p_info) {
i++; i++;
@ -802,7 +802,7 @@ optional<result> simplifier::synth_congr(expr const & e, F && simp) {
lean_assert(is_app(e)); lean_assert(is_app(e));
buffer<expr> args; buffer<expr> args;
expr f = get_app_args(e, args); expr f = get_app_args(e, args);
auto congr_lemma = mk_congr_lemma_for_simp(f, args.size()); auto congr_lemma = mk_specialized_congr_lemma_for_simp(e);
if (!congr_lemma) return optional<result>(); if (!congr_lemma) return optional<result>();
expr proof = congr_lemma->get_proof(); expr proof = congr_lemma->get_proof();
expr type = congr_lemma->get_type(); expr type = congr_lemma->get_type();
@ -849,7 +849,7 @@ optional<result> simplifier::synth_congr(expr const & e, F && simp) {
expr simplifier::remove_unnecessary_casts(expr const & e) { expr simplifier::remove_unnecessary_casts(expr const & e) {
buffer<expr> args; buffer<expr> args;
expr f = get_app_args(e, args); expr f = get_app_args(e, args);
fun_info f_info = get_fun_info(f, args.size()); fun_info f_info = get_specialized_fun_info(e);
int i = -1; int i = -1;
for_each(f_info.get_params_info(), [&](param_info const & p_info) { for_each(f_info.get_params_info(), [&](param_info const & p_info) {
i++; i++;

View file

@ -28,7 +28,9 @@ struct congr_lemma_manager::imp {
typedef expr_unsigned key; typedef expr_unsigned key;
typedef expr_unsigned_map<result> cache; typedef expr_unsigned_map<result> cache;
cache m_simp_cache; cache m_simp_cache;
cache m_simp_cache_spec;
cache m_cache; cache m_cache;
cache m_cache_spec;
cache m_rel_cache[2]; cache m_rel_cache[2];
relation_info_getter m_relation_info_getter; relation_info_getter m_relation_info_getter;
@ -298,7 +300,7 @@ struct congr_lemma_manager::imp {
return optional<result>(congr_type, congr_proof, to_list(kinds)); return optional<result>(congr_type, congr_proof, to_list(kinds));
} }
buffer<expr> rhss1; buffer<expr> rhss1;
get_app_args(rhs1, rhss1); get_app_args_at_most(rhs1, rhss.size(), rhss1);
lean_assert(rhss.size() == rhss1.size()); lean_assert(rhss.size() == rhss1.size());
expr a = mk_app(fn, i, rhss1.data()); expr a = mk_app(fn, i, rhss1.data());
expr pr2 = m_builder.mk_eq_refl(a); expr pr2 = m_builder.mk_eq_refl(a);
@ -321,18 +323,10 @@ struct congr_lemma_manager::imp {
} }
} }
public: optional<result> mk_congr_simp(expr const & fn, unsigned nargs, fun_info const & finfo) {
imp(app_builder & b, fun_info_manager & fm):
m_builder(b), m_fmanager(fm), m_ctx(fm.ctx()),
m_relation_info_getter(mk_relation_info_getter(fm.ctx().env())) {}
type_context & ctx() { return m_ctx; }
optional<result> mk_congr_simp(expr const & fn, unsigned nargs) {
auto r = m_simp_cache.find(key(fn, nargs)); auto r = m_simp_cache.find(key(fn, nargs));
if (r != m_simp_cache.end()) if (r != m_simp_cache.end())
return optional<result>(r->second); return optional<result>(r->second);
fun_info finfo = m_fmanager.get(fn, nargs);
list<unsigned> const & result_deps = finfo.get_result_dependencies(); list<unsigned> const & result_deps = finfo.get_result_dependencies();
buffer<congr_arg_kind> kinds; buffer<congr_arg_kind> kinds;
buffer<param_info> pinfos; buffer<param_info> pinfos;
@ -386,16 +380,10 @@ public:
} }
} }
optional<result> mk_congr_simp(expr const & fn) { optional<result> mk_congr(expr const & fn, unsigned nargs, fun_info const & finfo) {
fun_info finfo = m_fmanager.get(fn);
return mk_congr_simp(fn, finfo.get_arity());
}
optional<result> mk_congr(expr const & fn, unsigned nargs) {
auto r = m_cache.find(key(fn, nargs)); auto r = m_cache.find(key(fn, nargs));
if (r != m_cache.end()) if (r != m_cache.end())
return optional<result>(r->second); return optional<result>(r->second);
fun_info finfo = m_fmanager.get(fn, nargs);
optional<result> simp_lemma = mk_congr_simp(fn, nargs); optional<result> simp_lemma = mk_congr_simp(fn, nargs);
if (!simp_lemma) if (!simp_lemma)
return optional<result>(); return optional<result>();
@ -427,9 +415,86 @@ public:
return new_r; return new_r;
} }
void pre_specialize(expr const & a, expr & g, unsigned & prefix_sz, unsigned & num_rest_args) {
fun_info finfo = m_fmanager.get_specialized(a);
prefix_sz = 0;
for (param_info const & pinfo : finfo.get_params_info()) {
if (!pinfo.specialized())
break;
prefix_sz++;
}
num_rest_args = finfo.get_arity() - prefix_sz;
g = a;
for (unsigned i = 0; i < num_rest_args; i++) {
g = app_fn(g);
}
}
result mk_specialize_result(result const & r, unsigned prefix_sz) {
list<congr_arg_kind> new_arg_kinds = r.get_arg_kinds();
for (unsigned i = 0; i < prefix_sz; i++)
new_arg_kinds = cons(congr_arg_kind::FixedNoParam, new_arg_kinds);
return result(r.get_type(), r.get_proof(), new_arg_kinds);
}
public:
imp(app_builder & b, fun_info_manager & fm):
m_builder(b), m_fmanager(fm), m_ctx(fm.ctx()),
m_relation_info_getter(mk_relation_info_getter(fm.ctx().env())) {}
type_context & ctx() { return m_ctx; }
optional<result> mk_congr_simp(expr const & fn, unsigned nargs) {
fun_info finfo = m_fmanager.get(fn, nargs);
return mk_congr_simp(fn, nargs, finfo);
}
optional<result> mk_congr_simp(expr const & fn) {
fun_info finfo = m_fmanager.get(fn);
return mk_congr_simp(fn, finfo.get_arity(), finfo);
}
optional<result> mk_specialized_congr_simp(expr const & a) {
lean_assert(is_app(a));
expr g; unsigned prefix_sz, num_rest_args;
pre_specialize(a, g, prefix_sz, num_rest_args);
key k(g, num_rest_args);
auto it = m_simp_cache_spec.find(k);
if (it != m_simp_cache_spec.end())
return optional<result>(it->second);
auto r = mk_congr_simp(g, num_rest_args);
if (!r)
return optional<result>();
result new_r = mk_specialize_result(*r, prefix_sz);
m_simp_cache_spec.insert(mk_pair(k, new_r));
return optional<result>(new_r);
}
optional<result> mk_congr(expr const & fn, unsigned nargs) {
fun_info finfo = m_fmanager.get(fn, nargs);
return mk_congr(fn, nargs, finfo);
}
optional<result> mk_congr(expr const & fn) { optional<result> mk_congr(expr const & fn) {
fun_info finfo = m_fmanager.get(fn); fun_info finfo = m_fmanager.get(fn);
return mk_congr(fn, finfo.get_arity()); return mk_congr(fn, finfo.get_arity(), finfo);
}
optional<result> mk_specialized_congr(expr const & a) {
lean_assert(is_app(a));
expr g; unsigned prefix_sz, num_rest_args;
pre_specialize(a, g, prefix_sz, num_rest_args);
key k(g, num_rest_args);
auto it = m_cache_spec.find(k);
if (it != m_cache_spec.end())
return optional<result>(it->second);
auto r = mk_congr(g, num_rest_args);
if (!r) {
return optional<result>();
}
result new_r = mk_specialize_result(*r, prefix_sz);
m_cache_spec.insert(mk_pair(k, new_r));
return optional<result>(new_r);
} }
/** \brief Given an equivalence relation \c R, create the congruence lemma /** \brief Given an equivalence relation \c R, create the congruence lemma
@ -538,13 +603,18 @@ auto congr_lemma_manager::mk_congr_simp(expr const & fn) -> optional<result> {
auto congr_lemma_manager::mk_congr_simp(expr const & fn, unsigned nargs) -> optional<result> { auto congr_lemma_manager::mk_congr_simp(expr const & fn, unsigned nargs) -> optional<result> {
return m_ptr->mk_congr_simp(fn, nargs); return m_ptr->mk_congr_simp(fn, nargs);
} }
auto congr_lemma_manager::mk_specialized_congr_simp(expr const & a) -> optional<result> {
return m_ptr->mk_specialized_congr_simp(a);
}
auto congr_lemma_manager::mk_congr(expr const & fn) -> optional<result> { auto congr_lemma_manager::mk_congr(expr const & fn) -> optional<result> {
return m_ptr->mk_congr(fn); return m_ptr->mk_congr(fn);
} }
auto congr_lemma_manager::mk_congr(expr const & fn, unsigned nargs) -> optional<result> { auto congr_lemma_manager::mk_congr(expr const & fn, unsigned nargs) -> optional<result> {
return m_ptr->mk_congr(fn, nargs); return m_ptr->mk_congr(fn, nargs);
} }
auto congr_lemma_manager::mk_specialized_congr(expr const & fn) -> optional<result> {
return m_ptr->mk_specialized_congr(fn);
}
auto congr_lemma_manager::mk_rel_iff_congr(expr const & R) -> optional<result> { auto congr_lemma_manager::mk_rel_iff_congr(expr const & R) -> optional<result> {
return m_ptr->mk_rel_iff_congr(R); return m_ptr->mk_rel_iff_congr(R);
} }

View file

@ -49,9 +49,13 @@ public:
optional<result> mk_congr_simp(expr const & fn); optional<result> mk_congr_simp(expr const & fn);
optional<result> mk_congr_simp(expr const & fn, unsigned nargs); optional<result> mk_congr_simp(expr const & fn, unsigned nargs);
/* Create a specialized theorem using (a prefix of) the arguments of the given application. */
optional<result> mk_specialized_congr_simp(expr const & a);
optional<result> mk_congr(expr const & fn); optional<result> mk_congr(expr const & fn);
optional<result> mk_congr(expr const & fn, unsigned nargs); optional<result> mk_congr(expr const & fn, unsigned nargs);
/* Create a specialized theorem using (a prefix of) the arguments of the given application. */
optional<result> mk_specialized_congr(expr const & a);
/** \brief If R is an equivalence relation, construct the congruence lemma /** \brief If R is an equivalence relation, construct the congruence lemma

View file

@ -9,10 +9,27 @@ Author: Leonardo de Moura
#include "kernel/for_each_fn.h" #include "kernel/for_each_fn.h"
#include "kernel/instantiate.h" #include "kernel/instantiate.h"
#include "kernel/abstract.h" #include "kernel/abstract.h"
#include "library/trace.h"
#include "library/replace_visitor.h" #include "library/replace_visitor.h"
#include "library/fun_info_manager.h" #include "library/fun_info_manager.h"
namespace lean { namespace lean {
static name * g_fun_info = nullptr;
void initialize_fun_info_manager() {
g_fun_info = new name("fun_info");
register_trace_class(*g_fun_info);
}
void finalize_fun_info_manager() {
delete g_fun_info;
}
#define lean_trace_fun_info(Code) lean_trace(*g_fun_info, Code)
static bool is_fun_info_trace_enabled() {
return is_trace_class_enabled(*g_fun_info);
}
fun_info_manager::fun_info_manager(type_context & ctx): fun_info_manager::fun_info_manager(type_context & ctx):
m_ctx(ctx) { m_ctx(ctx) {
} }
@ -95,35 +112,13 @@ static bool has_nonprop_nonsubsingleton_fwd_dep(unsigned i, buffer<param_info> c
if (fwd_pinfo.is_prop() || fwd_pinfo.is_subsingleton()) if (fwd_pinfo.is_prop() || fwd_pinfo.is_subsingleton())
continue; continue;
auto const & fwd_deps = fwd_pinfo.get_dependencies(); auto const & fwd_deps = fwd_pinfo.get_dependencies();
if (std::find(fwd_deps.begin(), fwd_deps.end(), i) == fwd_deps.end()) { if (std::find(fwd_deps.begin(), fwd_deps.end(), i) != fwd_deps.end()) {
return true; return true;
} }
} }
return false; return false;
} }
fun_info fun_info_manager::get_specialization(expr const & fn, buffer<expr> const & args, buffer<param_info> const & pinfos, list<unsigned> const & result_deps) {
buffer<param_info> new_pinfos;
expr type = m_ctx.relaxed_try_to_pi(m_ctx.infer(fn));
for (unsigned i = 0; i < args.size(); i++) {
expr new_type = m_ctx.relaxed_try_to_pi(instantiate(binding_body(type), args[i]));
expr arg_type = binding_domain(type);
param_info new_pinfo = pinfos[i];
new_pinfo.m_specialized = true;
if (!new_pinfo.m_prop) {
new_pinfo.m_prop = m_ctx.is_prop(arg_type);
new_pinfo.m_subsingleton = new_pinfo.m_prop;
}
if (!new_pinfo.m_subsingleton) {
new_pinfo.m_subsingleton = static_cast<bool>(m_ctx.mk_subsingleton_instance(arg_type));
}
new_pinfos.push_back(new_pinfo);
type = new_type;
}
bool spec = true;
return fun_info(new_pinfos.size(), to_list(new_pinfos), result_deps, spec);
}
/* Copy the first prefix_sz entries from pinfos to new_pinfos and mark them as m_specialized = true */ /* Copy the first prefix_sz entries from pinfos to new_pinfos and mark them as m_specialized = true */
static void copy_prefix(unsigned prefix_sz, buffer<param_info> const & pinfos, buffer<param_info> & new_pinfos) { static void copy_prefix(unsigned prefix_sz, buffer<param_info> const & pinfos, buffer<param_info> & new_pinfos) {
for (unsigned i = 0; i < prefix_sz; i++) { for (unsigned i = 0; i < prefix_sz; i++) {
@ -131,7 +126,45 @@ static void copy_prefix(unsigned prefix_sz, buffer<param_info> const & pinfos, b
} }
} }
fun_info fun_info_manager::get_specialization(expr const & a) { void fun_info_manager::trace_if_unsupported(expr const & fn, buffer<expr> const & args, unsigned prefix_sz,
buffer<param_info> const & pinfos, fun_info const & result) {
if (!is_fun_info_trace_enabled())
return;
/* Check if all remaining arguments are nondependent or
dependent (but all forward dependencies are propositions or subsingletons) */
unsigned i = prefix_sz;
for (; i < pinfos.size(); i++) {
param_info const & pinfo = pinfos[i];
if (!pinfo.is_dep())
continue; /* nondependent argument */
if (has_nonprop_nonsubsingleton_fwd_dep(i, pinfos))
break; /* failed i-th argument has a forward dependent that is not a prop nor a subsingleton */
}
if (i == pinfos.size())
return; // It is *cheap* case
/* Expensive case */
/* We generate a trace message IF it would be possible to compute more precise information.
That is, there is an argument that is a proposition and/or subsingleton, but
the corresponding pinfo is not a marked a prop/subsingleton.
*/
i = 0;
for (param_info const & pinfo : result.get_params_info()) {
if (pinfo.is_prop() || pinfo.is_subsingleton())
continue;
expr arg_type = m_ctx.infer(args[i]);
if (m_ctx.is_prop(arg_type) || m_ctx.mk_subsingleton_instance(arg_type)) {
lean_trace_fun_info(
tout() << "approximating function information for '" << fn
<< "', this may affect the effectiveness of the simplifier and congruence closure modules, "
<< "more precise information can be efficiently computed if all parameters are moved to the beginning of the function\n";);
return;
}
i++;
}
}
fun_info fun_info_manager::get_specialized(expr const & a) {
lean_assert(is_app(a)); lean_assert(is_app(a));
buffer<expr> args; buffer<expr> args;
expr const & fn = get_app_args(a, args); expr const & fn = get_app_args(a, args);
@ -159,6 +192,10 @@ fun_info fun_info_manager::get_specialization(expr const & a) {
I don't think this is a big deal since we can write it as: I don't think this is a big deal since we can write it as:
p : Pi {A : Type} {B : Type} (x : A) (y : B), Prop p : Pi {A : Type} {B : Type} (x : A) (y : B), Prop
Therefore, we ignore the non-cheap cases, and pretend they are "cheap".
If tracing is enabled, we produce a tracing message whenever we find
a non-cheap case.
*/ */
buffer<param_info> pinfos; buffer<param_info> pinfos;
to_buffer(info.get_params_info(), pinfos); to_buffer(info.get_params_info(), pinfos);
@ -174,36 +211,26 @@ fun_info fun_info_manager::get_specialization(expr const & a) {
break; break;
} }
unsigned prefix_sz = i; unsigned prefix_sz = i;
/* Check if all remaining arguments are nondependent or if (prefix_sz == 0) {
dependent (but all forward dependencies are propositions or subsingletons) */ trace_if_unsupported(fn, args, prefix_sz, pinfos, info);
for (; i < pinfos.size(); i++) { return info;
param_info const & pinfo = pinfos[i];
if (!pinfo.is_dep())
continue; /* nondependent argument */
if (has_nonprop_nonsubsingleton_fwd_dep(i, pinfos))
break; /* failed i-th argument has a forward dependent that is not a prop nor a subsingleton */
} }
if (i < pinfos.size()) { /* Get g : fn + prefix */
/* Expensive case */ unsigned num_rest_args = pinfos.size() - prefix_sz;
return get_specialization(fn, args, pinfos, info.get_result_dependencies()); expr g = a;
} else { for (unsigned i = 0; i < num_rest_args; i++)
if (prefix_sz == 0) g = app_fn(g);
return info; expr_unsigned key(g, num_rest_args);
/* Get g : fn + prefix */ auto it = m_cache_get_spec.find(key);
unsigned num_rest_args = pinfos.size() - prefix_sz; if (it != m_cache_get_spec.end()) {
expr g = a; return it->second;
for (unsigned i = 0; i < num_rest_args; i++)
g = app_fn(g);
expr_unsigned key(g, num_rest_args);
auto it = m_cache_get_spec.find(key);
if (it != m_cache_get_spec.end())
return it->second;
buffer<param_info> new_pinfos;
copy_prefix(prefix_sz, pinfos, new_pinfos);
auto result_deps = get_core(g, new_pinfos, num_rest_args);
fun_info r(new_pinfos.size(), to_list(new_pinfos), result_deps);
m_cache_get_spec.insert(mk_pair(key, r));
return r;
} }
buffer<param_info> new_pinfos;
copy_prefix(prefix_sz, pinfos, new_pinfos);
auto result_deps = get_core(g, new_pinfos, num_rest_args);
fun_info r(new_pinfos.size(), to_list(new_pinfos), result_deps);
m_cache_get_spec.insert(mk_pair(key, r));
trace_if_unsupported(fn, args, prefix_sz, pinfos, r);
return r;
} }
} }

View file

@ -61,20 +61,16 @@ public:
/** \brief Function information produced by \c fun_info_manager */ /** \brief Function information produced by \c fun_info_manager */
class fun_info { class fun_info {
/* m_specialized is true if the information was produced using the function arguments,
and all m_specialized = true for all m_params_info */
unsigned m_arity; unsigned m_arity;
bool m_specialized;
list<param_info> m_params_info; list<param_info> m_params_info;
list<unsigned> m_deps; // resulting type dependencies list<unsigned> m_deps; // resulting type dependencies
public: public:
fun_info():m_arity(0), m_specialized(false) {} fun_info():m_arity(0) {}
fun_info(unsigned arity, list<param_info> const & info, list<unsigned> const & deps, bool spec = false): fun_info(unsigned arity, list<param_info> const & info, list<unsigned> const & deps):
m_arity(arity), m_specialized(spec), m_params_info(info), m_deps(deps) {} m_arity(arity), m_params_info(info), m_deps(deps) {}
unsigned get_arity() const { return m_arity; } unsigned get_arity() const { return m_arity; }
list<param_info> const & get_params_info() const { return m_params_info; } list<param_info> const & get_params_info() const { return m_params_info; }
list<unsigned> const & get_result_dependencies() const { return m_deps; } list<unsigned> const & get_result_dependencies() const { return m_deps; }
bool fully_specialized() const { return m_specialized; }
}; };
/** \brief Helper object for retrieving a summary for the parameters /** \brief Helper object for retrieving a summary for the parameters
@ -90,8 +86,8 @@ class fun_info_manager {
narg_cache m_cache_get_spec; narg_cache m_cache_get_spec;
list<unsigned> collect_deps(expr const & e, buffer<expr> const & locals); list<unsigned> collect_deps(expr const & e, buffer<expr> const & locals);
list<unsigned> get_core(expr const & e, buffer<param_info> & pinfos, unsigned max_args); list<unsigned> get_core(expr const & e, buffer<param_info> & pinfos, unsigned max_args);
fun_info get_specialization(expr const & fn, buffer<expr> const & args, void trace_if_unsupported(expr const & fn, buffer<expr> const & args, unsigned prefix_sz,
buffer<param_info> const & pinfos, list<unsigned> const & result_deps); buffer<param_info> const & pinfos, fun_info const & result);
public: public:
fun_info_manager(type_context & ctx); fun_info_manager(type_context & ctx);
type_context & ctx() { return m_ctx; } type_context & ctx() { return m_ctx; }
@ -115,6 +111,9 @@ public:
\remark \c get and \c get_specialization return the same result for all but \remark \c get and \c get_specialization return the same result for all but
is_prop and is_subsingleton. */ is_prop and is_subsingleton. */
fun_info get_specialization(expr const & app); fun_info get_specialized(expr const & app);
}; };
void initialize_fun_info_manager();
void finalize_fun_info_manager();
} }

View file

@ -47,6 +47,7 @@ Author: Leonardo de Moura
#include "library/congr_lemma_manager.h" #include "library/congr_lemma_manager.h"
#include "library/app_builder.h" #include "library/app_builder.h"
#include "library/attribute_manager.h" #include "library/attribute_manager.h"
#include "library/fun_info_manager.h"
namespace lean { namespace lean {
void initialize_library_module() { void initialize_library_module() {
@ -93,9 +94,11 @@ void initialize_library_module() {
initialize_light_rule_set(); initialize_light_rule_set();
initialize_congr_lemma_manager(); initialize_congr_lemma_manager();
initialize_app_builder(); initialize_app_builder();
initialize_fun_info_manager();
} }
void finalize_library_module() { void finalize_library_module() {
finalize_fun_info_manager();
finalize_app_builder(); finalize_app_builder();
finalize_congr_lemma_manager(); finalize_congr_lemma_manager();
finalize_light_rule_set(); finalize_light_rule_set();

View file

@ -0,0 +1,42 @@
import data.unit
open nat unit
constant f {A : Type} (a : A) {B : Type} (b : B) : nat
constant g : unit → nat
example (a b : unit) : g a = g b :=
by simp
example (a c : unit) (b d : nat) : b = d → f a b = f c d :=
by simp
constant h {A B : Type} : A → B → nat
example (a b c d : unit) : h a b = h c d :=
by simp
definition C [reducible] : nat → Type₁
| nat.zero := unit
| (nat.succ a) := nat
constant g₂ : Π {n : nat}, C n → nat → nat
example (a b : C zero) (c d : nat) : c = d → g₂ a c = g₂ b d :=
by simp
example (n : nat) (h : zero = n) (a b : C n) (c d : nat) : c = d → g₂ a c = g₂ b d :=
by simp
-- The following one cannot be solved as is
-- example (a c : nat) (b d : unit) : a = c → b = d → f a b = f c d :=
-- by simp
-- But, we can use the following trick
definition f_aux {A B : Type} (a : A) (b : B) := f a b
lemma to_f_aux [simp] {A B : Type} (a : A) (b : B) : f a b = f_aux a b :=
rfl
example (a c : nat) (b d : unit) : a = c → b = d → f a b = f c d :=
by simp