feat(library/fun_info_manager,library/congr_lemma_manager,blast/simplifier): specialized congruence lemmas
We still need a lot of polishing.
This commit is contained in:
parent
930fcddace
commit
f3b8aef24c
10 changed files with 353 additions and 149 deletions
|
@ -32,18 +32,27 @@ unsigned abstract_expr_manager::hash(expr const & e) {
|
|||
m_locals.pop_back();
|
||||
return h;
|
||||
case expr_kind::App:
|
||||
// TODO(Leo): in the past we only had to apply instantiate_rev to the function.
|
||||
// We have to improve this.
|
||||
// One idea is to compute and cache the specialization prefix for f.
|
||||
// Then, we only need to apply instantiate_rev to f + prefix.
|
||||
// expr f = instantiate_rev(get_app_args(e, args), m_locals.size(), m_locals.data());
|
||||
expr new_e = instantiate_rev(e, m_locals.size(), m_locals.data());
|
||||
optional<congr_lemma> f_congr = m_congr_lemma_manager.mk_specialized_congr(new_e);
|
||||
buffer<expr> args;
|
||||
expr f = instantiate_rev(get_app_args(e, args), m_locals.size(), m_locals.data());
|
||||
optional<congr_lemma> f_congr = m_congr_lemma_manager.mk_congr(f, args.size());
|
||||
expr const & f = get_app_args(new_e, args);
|
||||
h = hash(f);
|
||||
if (!f_congr) {
|
||||
for (expr const & arg : args) h = ::lean::hash(h, hash(arg));
|
||||
for (expr const & arg : args) {
|
||||
h = ::lean::hash(h, hash(arg));
|
||||
}
|
||||
} else {
|
||||
int i = -1;
|
||||
unsigned i = 0;
|
||||
for_each(f_congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
|
||||
if (c_kind != congr_arg_kind::Cast) {
|
||||
h = ::lean::hash(h, hash(args[i]));
|
||||
}
|
||||
i++;
|
||||
if (c_kind == congr_arg_kind::Cast) return;
|
||||
h = ::lean::hash(h, hash(args[i]));
|
||||
});
|
||||
}
|
||||
return h;
|
||||
|
@ -52,61 +61,70 @@ unsigned abstract_expr_manager::hash(expr const & e) {
|
|||
}
|
||||
|
||||
bool abstract_expr_manager::is_equal(expr const & a, expr const & b) {
|
||||
if (is_eqp(a, b)) return true;
|
||||
if (hash(a) != hash(b)) return false;
|
||||
if (a.kind() != b.kind()) return false;
|
||||
if (is_var(a)) return var_idx(a) == var_idx(b);
|
||||
bool is_eq;
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var:
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
case expr_kind::Constant: case expr_kind::Sort:
|
||||
return a == b;
|
||||
case expr_kind::Meta: case expr_kind::Local:
|
||||
return mlocal_name(a) == mlocal_name(b) && is_equal(mlocal_type(a), mlocal_type(b));
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
if (!is_equal(binding_domain(a), binding_domain(b))) return false;
|
||||
// see comment at abstract_expr_manager::hash
|
||||
m_locals.push_back(instantiate_rev(m_tctx.mk_tmp_local(binding_domain(a)), m_locals.size(), m_locals.data()));
|
||||
is_eq = is_equal(binding_body(a), binding_body(b));
|
||||
m_locals.pop_back();
|
||||
return is_eq;
|
||||
case expr_kind::Macro:
|
||||
if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b))
|
||||
return false;
|
||||
for (unsigned i = 0; i < macro_num_args(a); i++) {
|
||||
if (!is_equal(macro_arg(a, i), macro_arg(b, i)))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
case expr_kind::App:
|
||||
buffer<expr> a_args, b_args;
|
||||
expr f_a = get_app_args(a, a_args);
|
||||
expr f_b = get_app_args(b, b_args);
|
||||
if (!is_equal(f_a, f_b)) return false;
|
||||
if (a_args.size() != b_args.size()) return false;
|
||||
expr f = instantiate_rev(f_a, m_locals.size(), m_locals.data());
|
||||
optional<congr_lemma> f_congr = m_congr_lemma_manager.mk_congr(f, a_args.size());
|
||||
bool not_equal = false;
|
||||
if (!f_congr) {
|
||||
for (unsigned i = 0; i < a_args.size(); ++i) {
|
||||
if (!is_equal(a_args[i], b_args[i])) {
|
||||
not_equal = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int i = -1;
|
||||
for_each(f_congr->get_arg_kinds(), [&](congr_arg_kind const & c_kind) {
|
||||
if (not_equal) return;
|
||||
i++;
|
||||
if (c_kind == congr_arg_kind::Cast) return;
|
||||
if (!is_equal(a_args[i], b_args[i])) not_equal = true;
|
||||
});
|
||||
}
|
||||
return !not_equal;
|
||||
}
|
||||
if (is_eqp(a, b)) return true;
|
||||
if (a.kind() != b.kind()) return false;
|
||||
if (is_var(a)) return var_idx(a) == var_idx(b);
|
||||
bool is_eq;
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var:
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
case expr_kind::Constant: case expr_kind::Sort:
|
||||
return a == b;
|
||||
case expr_kind::Meta: case expr_kind::Local:
|
||||
return mlocal_name(a) == mlocal_name(b) && is_equal(mlocal_type(a), mlocal_type(b));
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
if (!is_equal(binding_domain(a), binding_domain(b))) return false;
|
||||
// see comment at abstract_expr_manager::hash
|
||||
m_locals.push_back(instantiate_rev(m_tctx.mk_tmp_local(binding_domain(a)), m_locals.size(), m_locals.data()));
|
||||
is_eq = is_equal(binding_body(a), binding_body(b));
|
||||
m_locals.pop_back();
|
||||
return is_eq;
|
||||
case expr_kind::Macro:
|
||||
if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b))
|
||||
return false;
|
||||
for (unsigned i = 0; i < macro_num_args(a); i++) {
|
||||
if (!is_equal(macro_arg(a, i), macro_arg(b, i)))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
case expr_kind::App:
|
||||
if (!is_equal(get_app_fn(a), get_app_fn(b))) {
|
||||
return false;
|
||||
}
|
||||
if (get_app_num_args(a) != get_app_num_args(b)) {
|
||||
return false;
|
||||
}
|
||||
// See comment in the hash function
|
||||
expr new_a = instantiate_rev(a, m_locals.size(), m_locals.data());
|
||||
expr new_b = instantiate_rev(b, m_locals.size(), m_locals.data());
|
||||
optional<congr_lemma> congra = m_congr_lemma_manager.mk_specialized_congr(new_a);
|
||||
optional<congr_lemma> congrb = m_congr_lemma_manager.mk_specialized_congr(new_b);
|
||||
buffer<expr> a_args, b_args;
|
||||
get_app_args(new_a, a_args);
|
||||
get_app_args(new_b, b_args);
|
||||
bool not_equal = false;
|
||||
if (!congra || !congrb) {
|
||||
for (unsigned i = 0; i < a_args.size(); ++i) {
|
||||
if (!is_equal(a_args[i], b_args[i])) {
|
||||
not_equal = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
unsigned i = 0;
|
||||
for_each2(congra->get_arg_kinds(),
|
||||
congrb->get_arg_kinds(),
|
||||
[&](congr_arg_kind const & ca_kind, congr_arg_kind const & cb_kind) {
|
||||
if (not_equal)
|
||||
return;
|
||||
if (ca_kind != cb_kind || (ca_kind != congr_arg_kind::Cast && !is_equal(a_args[i], b_args[i]))) {
|
||||
not_equal = true;
|
||||
}
|
||||
i++;
|
||||
});
|
||||
}
|
||||
return !not_equal;
|
||||
}
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -719,6 +719,10 @@ public:
|
|||
return m_congr_lemma_manager.mk_congr_simp(fn);
|
||||
}
|
||||
|
||||
optional<congr_lemma> mk_specialized_congr_lemma_for_simp(expr const & fn) {
|
||||
return m_congr_lemma_manager.mk_specialized_congr_simp(fn);
|
||||
}
|
||||
|
||||
optional<congr_lemma> mk_congr_lemma(expr const & fn, unsigned num_args) {
|
||||
return m_congr_lemma_manager.mk_congr(fn, num_args);
|
||||
}
|
||||
|
@ -727,6 +731,10 @@ public:
|
|||
return m_congr_lemma_manager.mk_congr(fn);
|
||||
}
|
||||
|
||||
optional<congr_lemma> mk_specialized_congr_lemma(expr const & a) {
|
||||
return m_congr_lemma_manager.mk_specialized_congr(a);
|
||||
}
|
||||
|
||||
optional<congr_lemma> mk_rel_iff_congr(expr const & fn) {
|
||||
return m_congr_lemma_manager.mk_rel_iff_congr(fn);
|
||||
}
|
||||
|
@ -743,6 +751,10 @@ public:
|
|||
return m_fun_info_manager.get(fn, nargs);
|
||||
}
|
||||
|
||||
fun_info get_specialized_fun_info(expr const & a) {
|
||||
return m_fun_info_manager.get_specialized(a);
|
||||
}
|
||||
|
||||
unsigned abstract_hash(expr const & e) {
|
||||
return m_abstract_expr_manager.hash(e);
|
||||
}
|
||||
|
@ -1083,6 +1095,11 @@ optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn) {
|
|||
return g_blastenv->mk_congr_lemma_for_simp(fn);
|
||||
}
|
||||
|
||||
optional<congr_lemma> mk_specialized_congr_lemma_for_simp(expr const & a) {
|
||||
lean_assert(g_blastenv);
|
||||
return g_blastenv->mk_specialized_congr_lemma_for_simp(a);
|
||||
}
|
||||
|
||||
optional<congr_lemma> mk_congr_lemma(expr const & fn, unsigned num_args) {
|
||||
lean_assert(g_blastenv);
|
||||
return g_blastenv->mk_congr_lemma(fn, num_args);
|
||||
|
@ -1093,6 +1110,11 @@ optional<congr_lemma> mk_congr_lemma(expr const & fn) {
|
|||
return g_blastenv->mk_congr_lemma(fn);
|
||||
}
|
||||
|
||||
optional<congr_lemma> mk_specialized_congr_lemma(expr const & a) {
|
||||
lean_assert(g_blastenv);
|
||||
return g_blastenv->mk_specialized_congr_lemma(a);
|
||||
}
|
||||
|
||||
optional<congr_lemma> mk_rel_iff_congr(expr const & fn) {
|
||||
lean_assert(g_blastenv);
|
||||
return g_blastenv->mk_rel_iff_congr(fn);
|
||||
|
@ -1113,6 +1135,11 @@ fun_info get_fun_info(expr const & fn, unsigned nargs) {
|
|||
return g_blastenv->get_fun_info(fn, nargs);
|
||||
}
|
||||
|
||||
fun_info get_specialized_fun_info(expr const & a) {
|
||||
lean_assert(g_blastenv);
|
||||
return g_blastenv->get_specialized_fun_info(a);
|
||||
}
|
||||
|
||||
unsigned abstract_hash(expr const & e) {
|
||||
lean_assert(g_blastenv);
|
||||
return g_blastenv->abstract_hash(e);
|
||||
|
@ -1326,11 +1353,13 @@ void initialize_blast() {
|
|||
|
||||
register_trace_class_alias("app_builder", name({"blast", "event"}));
|
||||
register_trace_class_alias(name({"simplifier", "failure"}), name({"blast", "event"}));
|
||||
register_trace_class_alias("fun_info", name({"blast", "event"}));
|
||||
|
||||
register_trace_class_alias(name({"cc", "propagation"}), "blast");
|
||||
|
||||
register_trace_class_alias("blast", "blast_detailed");
|
||||
register_trace_class_alias("app_builder", "blast_detailed");
|
||||
register_trace_class_alias("fun_info", "blast_detailed");
|
||||
register_trace_class_alias(name({"simplifier", "failure"}), "blast_detailed");
|
||||
register_trace_class_alias(name({"cc", "merge"}), "blast_detailed");
|
||||
|
||||
|
|
|
@ -126,6 +126,10 @@ optional<expr> mk_class_instance(expr const & e);
|
|||
optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn, unsigned num_args);
|
||||
/** \brief Similar to previous procedure, but num_args == arith of fn */
|
||||
optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn);
|
||||
/** \brief Similar to previous procedures, but \c a is a function application,
|
||||
and the arguments are taken into account when computing the lemma.
|
||||
\pre is_app(a) */
|
||||
optional<congr_lemma> mk_specialized_congr_lemma_for_simp(expr const & a);
|
||||
|
||||
/** \brief Create a congruence lemma for the given function.
|
||||
\pre num_args <= arity of fn
|
||||
|
@ -140,6 +144,10 @@ optional<congr_lemma> mk_congr_lemma_for_simp(expr const & fn);
|
|||
optional<congr_lemma> mk_congr_lemma(expr const & fn, unsigned num_args);
|
||||
/** \brief Similar to previous procedure, but num_args == arith of fn */
|
||||
optional<congr_lemma> mk_congr_lemma(expr const & fn);
|
||||
/** \brief Similar to previous procedures, but \c a is a function application,
|
||||
and the arguments are taken into account when computing the lemma.
|
||||
\pre is_app(a) */
|
||||
optional<congr_lemma> mk_specialized_congr_lemma(expr const & a);
|
||||
optional<congr_lemma> mk_rel_iff_congr(expr const & fn);
|
||||
optional<congr_lemma> mk_rel_eq_congr(expr const & fn);
|
||||
|
||||
|
@ -148,6 +156,10 @@ fun_info get_fun_info(expr const & fn);
|
|||
/** \brief Retrieve information for the given function.
|
||||
\pre nargs <= arity fn. */
|
||||
fun_info get_fun_info(expr const & fn, unsigned nargs);
|
||||
/** \brief Retrieve information for the given function-application
|
||||
taking into account the actual arguments.
|
||||
\pre is_app(a) */
|
||||
fun_info get_specialized_fun_info(expr const & a);
|
||||
|
||||
/** \brief Hash and equality test for abstract expressions */
|
||||
unsigned abstract_hash(expr const & e);
|
||||
|
|
|
@ -267,17 +267,17 @@ optional<result> simplifier::cache_lookup(expr const & e) {
|
|||
if (e == e_old) return optional<result>(r_old);
|
||||
lean_assert(is_app(e_old));
|
||||
buffer<expr> new_args, old_args;
|
||||
expr const & f_new = get_app_args(e, new_args);
|
||||
DEBUG_CODE(expr const & f_new =) get_app_args(e, new_args);
|
||||
DEBUG_CODE(expr const & f_old =) get_app_args(e_old, old_args);
|
||||
lean_assert(f_new == f_old);
|
||||
auto congr_lemma = mk_congr_lemma(f_new, new_args.size());
|
||||
auto congr_lemma = mk_specialized_congr_lemma(e);
|
||||
if (!congr_lemma) return optional<result>();
|
||||
expr proof = congr_lemma->get_proof();
|
||||
expr type = congr_lemma->get_type();
|
||||
|
||||
unsigned i = 0;
|
||||
for_each(congr_lemma->get_arg_kinds(), [&](congr_arg_kind const & ckind) {
|
||||
lean_assert(ckind == congr_arg_kind::Cast || new_args[i] == old_args[i]);
|
||||
lean_assert(ckind == congr_arg_kind::Cast || new_args[i] == old_args[i], static_cast<unsigned>(ckind), new_args[i], old_args[i]);
|
||||
expr rfl;
|
||||
switch (ckind) {
|
||||
case congr_arg_kind::Fixed:
|
||||
|
@ -469,7 +469,7 @@ result simplifier::simplify_pi(expr const & e) {
|
|||
expr simplifier::unfold_reducible_instances(expr const & e) {
|
||||
buffer<expr> args;
|
||||
expr f = get_app_args(e, args);
|
||||
fun_info f_info = get_fun_info(f, args.size());
|
||||
fun_info f_info = get_specialized_fun_info(e);
|
||||
int i = -1;
|
||||
for_each(f_info.get_params_info(), [&](param_info const & p_info) {
|
||||
i++;
|
||||
|
@ -802,7 +802,7 @@ optional<result> simplifier::synth_congr(expr const & e, F && simp) {
|
|||
lean_assert(is_app(e));
|
||||
buffer<expr> args;
|
||||
expr f = get_app_args(e, args);
|
||||
auto congr_lemma = mk_congr_lemma_for_simp(f, args.size());
|
||||
auto congr_lemma = mk_specialized_congr_lemma_for_simp(e);
|
||||
if (!congr_lemma) return optional<result>();
|
||||
expr proof = congr_lemma->get_proof();
|
||||
expr type = congr_lemma->get_type();
|
||||
|
@ -849,7 +849,7 @@ optional<result> simplifier::synth_congr(expr const & e, F && simp) {
|
|||
expr simplifier::remove_unnecessary_casts(expr const & e) {
|
||||
buffer<expr> args;
|
||||
expr f = get_app_args(e, args);
|
||||
fun_info f_info = get_fun_info(f, args.size());
|
||||
fun_info f_info = get_specialized_fun_info(e);
|
||||
int i = -1;
|
||||
for_each(f_info.get_params_info(), [&](param_info const & p_info) {
|
||||
i++;
|
||||
|
|
|
@ -28,7 +28,9 @@ struct congr_lemma_manager::imp {
|
|||
typedef expr_unsigned key;
|
||||
typedef expr_unsigned_map<result> cache;
|
||||
cache m_simp_cache;
|
||||
cache m_simp_cache_spec;
|
||||
cache m_cache;
|
||||
cache m_cache_spec;
|
||||
cache m_rel_cache[2];
|
||||
relation_info_getter m_relation_info_getter;
|
||||
|
||||
|
@ -298,7 +300,7 @@ struct congr_lemma_manager::imp {
|
|||
return optional<result>(congr_type, congr_proof, to_list(kinds));
|
||||
}
|
||||
buffer<expr> rhss1;
|
||||
get_app_args(rhs1, rhss1);
|
||||
get_app_args_at_most(rhs1, rhss.size(), rhss1);
|
||||
lean_assert(rhss.size() == rhss1.size());
|
||||
expr a = mk_app(fn, i, rhss1.data());
|
||||
expr pr2 = m_builder.mk_eq_refl(a);
|
||||
|
@ -321,18 +323,10 @@ struct congr_lemma_manager::imp {
|
|||
}
|
||||
}
|
||||
|
||||
public:
|
||||
imp(app_builder & b, fun_info_manager & fm):
|
||||
m_builder(b), m_fmanager(fm), m_ctx(fm.ctx()),
|
||||
m_relation_info_getter(mk_relation_info_getter(fm.ctx().env())) {}
|
||||
|
||||
type_context & ctx() { return m_ctx; }
|
||||
|
||||
optional<result> mk_congr_simp(expr const & fn, unsigned nargs) {
|
||||
optional<result> mk_congr_simp(expr const & fn, unsigned nargs, fun_info const & finfo) {
|
||||
auto r = m_simp_cache.find(key(fn, nargs));
|
||||
if (r != m_simp_cache.end())
|
||||
return optional<result>(r->second);
|
||||
fun_info finfo = m_fmanager.get(fn, nargs);
|
||||
list<unsigned> const & result_deps = finfo.get_result_dependencies();
|
||||
buffer<congr_arg_kind> kinds;
|
||||
buffer<param_info> pinfos;
|
||||
|
@ -386,16 +380,10 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
optional<result> mk_congr_simp(expr const & fn) {
|
||||
fun_info finfo = m_fmanager.get(fn);
|
||||
return mk_congr_simp(fn, finfo.get_arity());
|
||||
}
|
||||
|
||||
optional<result> mk_congr(expr const & fn, unsigned nargs) {
|
||||
optional<result> mk_congr(expr const & fn, unsigned nargs, fun_info const & finfo) {
|
||||
auto r = m_cache.find(key(fn, nargs));
|
||||
if (r != m_cache.end())
|
||||
return optional<result>(r->second);
|
||||
fun_info finfo = m_fmanager.get(fn, nargs);
|
||||
optional<result> simp_lemma = mk_congr_simp(fn, nargs);
|
||||
if (!simp_lemma)
|
||||
return optional<result>();
|
||||
|
@ -427,9 +415,86 @@ public:
|
|||
return new_r;
|
||||
}
|
||||
|
||||
void pre_specialize(expr const & a, expr & g, unsigned & prefix_sz, unsigned & num_rest_args) {
|
||||
fun_info finfo = m_fmanager.get_specialized(a);
|
||||
prefix_sz = 0;
|
||||
for (param_info const & pinfo : finfo.get_params_info()) {
|
||||
if (!pinfo.specialized())
|
||||
break;
|
||||
prefix_sz++;
|
||||
}
|
||||
num_rest_args = finfo.get_arity() - prefix_sz;
|
||||
g = a;
|
||||
for (unsigned i = 0; i < num_rest_args; i++) {
|
||||
g = app_fn(g);
|
||||
}
|
||||
}
|
||||
|
||||
result mk_specialize_result(result const & r, unsigned prefix_sz) {
|
||||
list<congr_arg_kind> new_arg_kinds = r.get_arg_kinds();
|
||||
for (unsigned i = 0; i < prefix_sz; i++)
|
||||
new_arg_kinds = cons(congr_arg_kind::FixedNoParam, new_arg_kinds);
|
||||
return result(r.get_type(), r.get_proof(), new_arg_kinds);
|
||||
}
|
||||
|
||||
public:
|
||||
imp(app_builder & b, fun_info_manager & fm):
|
||||
m_builder(b), m_fmanager(fm), m_ctx(fm.ctx()),
|
||||
m_relation_info_getter(mk_relation_info_getter(fm.ctx().env())) {}
|
||||
|
||||
type_context & ctx() { return m_ctx; }
|
||||
|
||||
optional<result> mk_congr_simp(expr const & fn, unsigned nargs) {
|
||||
fun_info finfo = m_fmanager.get(fn, nargs);
|
||||
return mk_congr_simp(fn, nargs, finfo);
|
||||
}
|
||||
|
||||
optional<result> mk_congr_simp(expr const & fn) {
|
||||
fun_info finfo = m_fmanager.get(fn);
|
||||
return mk_congr_simp(fn, finfo.get_arity(), finfo);
|
||||
}
|
||||
|
||||
optional<result> mk_specialized_congr_simp(expr const & a) {
|
||||
lean_assert(is_app(a));
|
||||
expr g; unsigned prefix_sz, num_rest_args;
|
||||
pre_specialize(a, g, prefix_sz, num_rest_args);
|
||||
key k(g, num_rest_args);
|
||||
auto it = m_simp_cache_spec.find(k);
|
||||
if (it != m_simp_cache_spec.end())
|
||||
return optional<result>(it->second);
|
||||
auto r = mk_congr_simp(g, num_rest_args);
|
||||
if (!r)
|
||||
return optional<result>();
|
||||
result new_r = mk_specialize_result(*r, prefix_sz);
|
||||
m_simp_cache_spec.insert(mk_pair(k, new_r));
|
||||
return optional<result>(new_r);
|
||||
}
|
||||
|
||||
optional<result> mk_congr(expr const & fn, unsigned nargs) {
|
||||
fun_info finfo = m_fmanager.get(fn, nargs);
|
||||
return mk_congr(fn, nargs, finfo);
|
||||
}
|
||||
|
||||
optional<result> mk_congr(expr const & fn) {
|
||||
fun_info finfo = m_fmanager.get(fn);
|
||||
return mk_congr(fn, finfo.get_arity());
|
||||
return mk_congr(fn, finfo.get_arity(), finfo);
|
||||
}
|
||||
|
||||
optional<result> mk_specialized_congr(expr const & a) {
|
||||
lean_assert(is_app(a));
|
||||
expr g; unsigned prefix_sz, num_rest_args;
|
||||
pre_specialize(a, g, prefix_sz, num_rest_args);
|
||||
key k(g, num_rest_args);
|
||||
auto it = m_cache_spec.find(k);
|
||||
if (it != m_cache_spec.end())
|
||||
return optional<result>(it->second);
|
||||
auto r = mk_congr(g, num_rest_args);
|
||||
if (!r) {
|
||||
return optional<result>();
|
||||
}
|
||||
result new_r = mk_specialize_result(*r, prefix_sz);
|
||||
m_cache_spec.insert(mk_pair(k, new_r));
|
||||
return optional<result>(new_r);
|
||||
}
|
||||
|
||||
/** \brief Given an equivalence relation \c R, create the congruence lemma
|
||||
|
@ -538,13 +603,18 @@ auto congr_lemma_manager::mk_congr_simp(expr const & fn) -> optional<result> {
|
|||
auto congr_lemma_manager::mk_congr_simp(expr const & fn, unsigned nargs) -> optional<result> {
|
||||
return m_ptr->mk_congr_simp(fn, nargs);
|
||||
}
|
||||
auto congr_lemma_manager::mk_specialized_congr_simp(expr const & a) -> optional<result> {
|
||||
return m_ptr->mk_specialized_congr_simp(a);
|
||||
}
|
||||
auto congr_lemma_manager::mk_congr(expr const & fn) -> optional<result> {
|
||||
return m_ptr->mk_congr(fn);
|
||||
}
|
||||
auto congr_lemma_manager::mk_congr(expr const & fn, unsigned nargs) -> optional<result> {
|
||||
return m_ptr->mk_congr(fn, nargs);
|
||||
}
|
||||
|
||||
auto congr_lemma_manager::mk_specialized_congr(expr const & fn) -> optional<result> {
|
||||
return m_ptr->mk_specialized_congr(fn);
|
||||
}
|
||||
auto congr_lemma_manager::mk_rel_iff_congr(expr const & R) -> optional<result> {
|
||||
return m_ptr->mk_rel_iff_congr(R);
|
||||
}
|
||||
|
|
|
@ -49,9 +49,13 @@ public:
|
|||
|
||||
optional<result> mk_congr_simp(expr const & fn);
|
||||
optional<result> mk_congr_simp(expr const & fn, unsigned nargs);
|
||||
/* Create a specialized theorem using (a prefix of) the arguments of the given application. */
|
||||
optional<result> mk_specialized_congr_simp(expr const & a);
|
||||
|
||||
optional<result> mk_congr(expr const & fn);
|
||||
optional<result> mk_congr(expr const & fn, unsigned nargs);
|
||||
/* Create a specialized theorem using (a prefix of) the arguments of the given application. */
|
||||
optional<result> mk_specialized_congr(expr const & a);
|
||||
|
||||
/** \brief If R is an equivalence relation, construct the congruence lemma
|
||||
|
||||
|
|
|
@ -9,10 +9,27 @@ Author: Leonardo de Moura
|
|||
#include "kernel/for_each_fn.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/abstract.h"
|
||||
#include "library/trace.h"
|
||||
#include "library/replace_visitor.h"
|
||||
#include "library/fun_info_manager.h"
|
||||
|
||||
namespace lean {
|
||||
static name * g_fun_info = nullptr;
|
||||
void initialize_fun_info_manager() {
|
||||
g_fun_info = new name("fun_info");
|
||||
register_trace_class(*g_fun_info);
|
||||
}
|
||||
|
||||
void finalize_fun_info_manager() {
|
||||
delete g_fun_info;
|
||||
}
|
||||
|
||||
#define lean_trace_fun_info(Code) lean_trace(*g_fun_info, Code)
|
||||
|
||||
static bool is_fun_info_trace_enabled() {
|
||||
return is_trace_class_enabled(*g_fun_info);
|
||||
}
|
||||
|
||||
fun_info_manager::fun_info_manager(type_context & ctx):
|
||||
m_ctx(ctx) {
|
||||
}
|
||||
|
@ -95,35 +112,13 @@ static bool has_nonprop_nonsubsingleton_fwd_dep(unsigned i, buffer<param_info> c
|
|||
if (fwd_pinfo.is_prop() || fwd_pinfo.is_subsingleton())
|
||||
continue;
|
||||
auto const & fwd_deps = fwd_pinfo.get_dependencies();
|
||||
if (std::find(fwd_deps.begin(), fwd_deps.end(), i) == fwd_deps.end()) {
|
||||
if (std::find(fwd_deps.begin(), fwd_deps.end(), i) != fwd_deps.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
fun_info fun_info_manager::get_specialization(expr const & fn, buffer<expr> const & args, buffer<param_info> const & pinfos, list<unsigned> const & result_deps) {
|
||||
buffer<param_info> new_pinfos;
|
||||
expr type = m_ctx.relaxed_try_to_pi(m_ctx.infer(fn));
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
expr new_type = m_ctx.relaxed_try_to_pi(instantiate(binding_body(type), args[i]));
|
||||
expr arg_type = binding_domain(type);
|
||||
param_info new_pinfo = pinfos[i];
|
||||
new_pinfo.m_specialized = true;
|
||||
if (!new_pinfo.m_prop) {
|
||||
new_pinfo.m_prop = m_ctx.is_prop(arg_type);
|
||||
new_pinfo.m_subsingleton = new_pinfo.m_prop;
|
||||
}
|
||||
if (!new_pinfo.m_subsingleton) {
|
||||
new_pinfo.m_subsingleton = static_cast<bool>(m_ctx.mk_subsingleton_instance(arg_type));
|
||||
}
|
||||
new_pinfos.push_back(new_pinfo);
|
||||
type = new_type;
|
||||
}
|
||||
bool spec = true;
|
||||
return fun_info(new_pinfos.size(), to_list(new_pinfos), result_deps, spec);
|
||||
}
|
||||
|
||||
/* Copy the first prefix_sz entries from pinfos to new_pinfos and mark them as m_specialized = true */
|
||||
static void copy_prefix(unsigned prefix_sz, buffer<param_info> const & pinfos, buffer<param_info> & new_pinfos) {
|
||||
for (unsigned i = 0; i < prefix_sz; i++) {
|
||||
|
@ -131,7 +126,45 @@ static void copy_prefix(unsigned prefix_sz, buffer<param_info> const & pinfos, b
|
|||
}
|
||||
}
|
||||
|
||||
fun_info fun_info_manager::get_specialization(expr const & a) {
|
||||
void fun_info_manager::trace_if_unsupported(expr const & fn, buffer<expr> const & args, unsigned prefix_sz,
|
||||
buffer<param_info> const & pinfos, fun_info const & result) {
|
||||
if (!is_fun_info_trace_enabled())
|
||||
return;
|
||||
/* Check if all remaining arguments are nondependent or
|
||||
dependent (but all forward dependencies are propositions or subsingletons) */
|
||||
unsigned i = prefix_sz;
|
||||
for (; i < pinfos.size(); i++) {
|
||||
param_info const & pinfo = pinfos[i];
|
||||
if (!pinfo.is_dep())
|
||||
continue; /* nondependent argument */
|
||||
if (has_nonprop_nonsubsingleton_fwd_dep(i, pinfos))
|
||||
break; /* failed i-th argument has a forward dependent that is not a prop nor a subsingleton */
|
||||
}
|
||||
if (i == pinfos.size())
|
||||
return; // It is *cheap* case
|
||||
|
||||
/* Expensive case */
|
||||
/* We generate a trace message IF it would be possible to compute more precise information.
|
||||
That is, there is an argument that is a proposition and/or subsingleton, but
|
||||
the corresponding pinfo is not a marked a prop/subsingleton.
|
||||
*/
|
||||
i = 0;
|
||||
for (param_info const & pinfo : result.get_params_info()) {
|
||||
if (pinfo.is_prop() || pinfo.is_subsingleton())
|
||||
continue;
|
||||
expr arg_type = m_ctx.infer(args[i]);
|
||||
if (m_ctx.is_prop(arg_type) || m_ctx.mk_subsingleton_instance(arg_type)) {
|
||||
lean_trace_fun_info(
|
||||
tout() << "approximating function information for '" << fn
|
||||
<< "', this may affect the effectiveness of the simplifier and congruence closure modules, "
|
||||
<< "more precise information can be efficiently computed if all parameters are moved to the beginning of the function\n";);
|
||||
return;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
fun_info fun_info_manager::get_specialized(expr const & a) {
|
||||
lean_assert(is_app(a));
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(a, args);
|
||||
|
@ -159,6 +192,10 @@ fun_info fun_info_manager::get_specialization(expr const & a) {
|
|||
I don't think this is a big deal since we can write it as:
|
||||
|
||||
p : Pi {A : Type} {B : Type} (x : A) (y : B), Prop
|
||||
|
||||
Therefore, we ignore the non-cheap cases, and pretend they are "cheap".
|
||||
If tracing is enabled, we produce a tracing message whenever we find
|
||||
a non-cheap case.
|
||||
*/
|
||||
buffer<param_info> pinfos;
|
||||
to_buffer(info.get_params_info(), pinfos);
|
||||
|
@ -174,36 +211,26 @@ fun_info fun_info_manager::get_specialization(expr const & a) {
|
|||
break;
|
||||
}
|
||||
unsigned prefix_sz = i;
|
||||
/* Check if all remaining arguments are nondependent or
|
||||
dependent (but all forward dependencies are propositions or subsingletons) */
|
||||
for (; i < pinfos.size(); i++) {
|
||||
param_info const & pinfo = pinfos[i];
|
||||
if (!pinfo.is_dep())
|
||||
continue; /* nondependent argument */
|
||||
if (has_nonprop_nonsubsingleton_fwd_dep(i, pinfos))
|
||||
break; /* failed i-th argument has a forward dependent that is not a prop nor a subsingleton */
|
||||
if (prefix_sz == 0) {
|
||||
trace_if_unsupported(fn, args, prefix_sz, pinfos, info);
|
||||
return info;
|
||||
}
|
||||
if (i < pinfos.size()) {
|
||||
/* Expensive case */
|
||||
return get_specialization(fn, args, pinfos, info.get_result_dependencies());
|
||||
} else {
|
||||
if (prefix_sz == 0)
|
||||
return info;
|
||||
/* Get g : fn + prefix */
|
||||
unsigned num_rest_args = pinfos.size() - prefix_sz;
|
||||
expr g = a;
|
||||
for (unsigned i = 0; i < num_rest_args; i++)
|
||||
g = app_fn(g);
|
||||
expr_unsigned key(g, num_rest_args);
|
||||
auto it = m_cache_get_spec.find(key);
|
||||
if (it != m_cache_get_spec.end())
|
||||
return it->second;
|
||||
buffer<param_info> new_pinfos;
|
||||
copy_prefix(prefix_sz, pinfos, new_pinfos);
|
||||
auto result_deps = get_core(g, new_pinfos, num_rest_args);
|
||||
fun_info r(new_pinfos.size(), to_list(new_pinfos), result_deps);
|
||||
m_cache_get_spec.insert(mk_pair(key, r));
|
||||
return r;
|
||||
/* Get g : fn + prefix */
|
||||
unsigned num_rest_args = pinfos.size() - prefix_sz;
|
||||
expr g = a;
|
||||
for (unsigned i = 0; i < num_rest_args; i++)
|
||||
g = app_fn(g);
|
||||
expr_unsigned key(g, num_rest_args);
|
||||
auto it = m_cache_get_spec.find(key);
|
||||
if (it != m_cache_get_spec.end()) {
|
||||
return it->second;
|
||||
}
|
||||
buffer<param_info> new_pinfos;
|
||||
copy_prefix(prefix_sz, pinfos, new_pinfos);
|
||||
auto result_deps = get_core(g, new_pinfos, num_rest_args);
|
||||
fun_info r(new_pinfos.size(), to_list(new_pinfos), result_deps);
|
||||
m_cache_get_spec.insert(mk_pair(key, r));
|
||||
trace_if_unsupported(fn, args, prefix_sz, pinfos, r);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,20 +61,16 @@ public:
|
|||
|
||||
/** \brief Function information produced by \c fun_info_manager */
|
||||
class fun_info {
|
||||
/* m_specialized is true if the information was produced using the function arguments,
|
||||
and all m_specialized = true for all m_params_info */
|
||||
unsigned m_arity;
|
||||
bool m_specialized;
|
||||
list<param_info> m_params_info;
|
||||
list<unsigned> m_deps; // resulting type dependencies
|
||||
public:
|
||||
fun_info():m_arity(0), m_specialized(false) {}
|
||||
fun_info(unsigned arity, list<param_info> const & info, list<unsigned> const & deps, bool spec = false):
|
||||
m_arity(arity), m_specialized(spec), m_params_info(info), m_deps(deps) {}
|
||||
fun_info():m_arity(0) {}
|
||||
fun_info(unsigned arity, list<param_info> const & info, list<unsigned> const & deps):
|
||||
m_arity(arity), m_params_info(info), m_deps(deps) {}
|
||||
unsigned get_arity() const { return m_arity; }
|
||||
list<param_info> const & get_params_info() const { return m_params_info; }
|
||||
list<unsigned> const & get_result_dependencies() const { return m_deps; }
|
||||
bool fully_specialized() const { return m_specialized; }
|
||||
};
|
||||
|
||||
/** \brief Helper object for retrieving a summary for the parameters
|
||||
|
@ -90,8 +86,8 @@ class fun_info_manager {
|
|||
narg_cache m_cache_get_spec;
|
||||
list<unsigned> collect_deps(expr const & e, buffer<expr> const & locals);
|
||||
list<unsigned> get_core(expr const & e, buffer<param_info> & pinfos, unsigned max_args);
|
||||
fun_info get_specialization(expr const & fn, buffer<expr> const & args,
|
||||
buffer<param_info> const & pinfos, list<unsigned> const & result_deps);
|
||||
void trace_if_unsupported(expr const & fn, buffer<expr> const & args, unsigned prefix_sz,
|
||||
buffer<param_info> const & pinfos, fun_info const & result);
|
||||
public:
|
||||
fun_info_manager(type_context & ctx);
|
||||
type_context & ctx() { return m_ctx; }
|
||||
|
@ -115,6 +111,9 @@ public:
|
|||
|
||||
\remark \c get and \c get_specialization return the same result for all but
|
||||
is_prop and is_subsingleton. */
|
||||
fun_info get_specialization(expr const & app);
|
||||
fun_info get_specialized(expr const & app);
|
||||
};
|
||||
|
||||
void initialize_fun_info_manager();
|
||||
void finalize_fun_info_manager();
|
||||
}
|
||||
|
|
|
@ -47,6 +47,7 @@ Author: Leonardo de Moura
|
|||
#include "library/congr_lemma_manager.h"
|
||||
#include "library/app_builder.h"
|
||||
#include "library/attribute_manager.h"
|
||||
#include "library/fun_info_manager.h"
|
||||
|
||||
namespace lean {
|
||||
void initialize_library_module() {
|
||||
|
@ -93,9 +94,11 @@ void initialize_library_module() {
|
|||
initialize_light_rule_set();
|
||||
initialize_congr_lemma_manager();
|
||||
initialize_app_builder();
|
||||
initialize_fun_info_manager();
|
||||
}
|
||||
|
||||
void finalize_library_module() {
|
||||
finalize_fun_info_manager();
|
||||
finalize_app_builder();
|
||||
finalize_congr_lemma_manager();
|
||||
finalize_light_rule_set();
|
||||
|
|
42
tests/lean/run/blast_simp_subsingleton.lean
Normal file
42
tests/lean/run/blast_simp_subsingleton.lean
Normal file
|
@ -0,0 +1,42 @@
|
|||
import data.unit
|
||||
open nat unit
|
||||
|
||||
constant f {A : Type} (a : A) {B : Type} (b : B) : nat
|
||||
|
||||
constant g : unit → nat
|
||||
|
||||
example (a b : unit) : g a = g b :=
|
||||
by simp
|
||||
|
||||
example (a c : unit) (b d : nat) : b = d → f a b = f c d :=
|
||||
by simp
|
||||
|
||||
constant h {A B : Type} : A → B → nat
|
||||
|
||||
example (a b c d : unit) : h a b = h c d :=
|
||||
by simp
|
||||
|
||||
definition C [reducible] : nat → Type₁
|
||||
| nat.zero := unit
|
||||
| (nat.succ a) := nat
|
||||
|
||||
constant g₂ : Π {n : nat}, C n → nat → nat
|
||||
|
||||
example (a b : C zero) (c d : nat) : c = d → g₂ a c = g₂ b d :=
|
||||
by simp
|
||||
|
||||
example (n : nat) (h : zero = n) (a b : C n) (c d : nat) : c = d → g₂ a c = g₂ b d :=
|
||||
by simp
|
||||
|
||||
-- The following one cannot be solved as is
|
||||
-- example (a c : nat) (b d : unit) : a = c → b = d → f a b = f c d :=
|
||||
-- by simp
|
||||
-- But, we can use the following trick
|
||||
|
||||
definition f_aux {A B : Type} (a : A) (b : B) := f a b
|
||||
|
||||
lemma to_f_aux [simp] {A B : Type} (a : A) (b : B) : f a b = f_aux a b :=
|
||||
rfl
|
||||
|
||||
example (a c : nat) (b d : unit) : a = c → b = d → f a b = f c d :=
|
||||
by simp
|
Loading…
Reference in a new issue