feat(library/blast/congruence_closure): basic support for heterogeneous equality

We still have to process the general congruence lemmas.
This commit is contained in:
Leonardo de Moura 2016-01-10 12:41:30 -08:00
parent 22a6b7f1c3
commit 934f3b67ff
7 changed files with 268 additions and 76 deletions

View file

@ -587,6 +587,8 @@ struct app_builder::imp {
} }
expr mk_eq_of_heq(expr const & H) { expr mk_eq_of_heq(expr const & H) {
if (is_constant(get_app_fn(H), get_heq_of_eq_name()))
return app_arg(H);
expr p = m_ctx->relaxed_whnf(m_ctx->infer(H)); expr p = m_ctx->relaxed_whnf(m_ctx->infer(H));
expr A, a, B, b; expr A, a, B, b;
if (!is_heq(p, A, a, B, b)) { if (!is_heq(p, A, a, B, b)) {

View file

@ -731,6 +731,10 @@ public:
return m_congr_lemma_manager.mk_congr(fn); return m_congr_lemma_manager.mk_congr(fn);
} }
optional<congr_lemma> mk_hcongr_lemma(expr const & fn, unsigned num_args) {
return m_congr_lemma_manager.mk_hcongr(fn, num_args);
}
optional<congr_lemma> mk_specialized_congr_lemma(expr const & a) { optional<congr_lemma> mk_specialized_congr_lemma(expr const & a) {
return m_congr_lemma_manager.mk_specialized_congr(a); return m_congr_lemma_manager.mk_specialized_congr(a);
} }
@ -1114,6 +1118,11 @@ optional<congr_lemma> mk_congr_lemma(expr const & fn) {
return g_blastenv->mk_congr_lemma(fn); return g_blastenv->mk_congr_lemma(fn);
} }
optional<congr_lemma> mk_hcongr_lemma(expr const & fn, unsigned num_args) {
lean_assert(g_blastenv);
return g_blastenv->mk_hcongr_lemma(fn, num_args);
}
optional<congr_lemma> mk_specialized_congr_lemma(expr const & a) { optional<congr_lemma> mk_specialized_congr_lemma(expr const & a) {
lean_assert(g_blastenv); lean_assert(g_blastenv);
return g_blastenv->mk_specialized_congr_lemma(a); return g_blastenv->mk_specialized_congr_lemma(a);

View file

@ -148,6 +148,9 @@ optional<congr_lemma> mk_congr_lemma(expr const & fn);
and the arguments are taken into account when computing the lemma. and the arguments are taken into account when computing the lemma.
\pre is_app(a) */ \pre is_app(a) */
optional<congr_lemma> mk_specialized_congr_lemma(expr const & a); optional<congr_lemma> mk_specialized_congr_lemma(expr const & a);
/** \brief Create more general congruence lemma based on heterogeneous equality. */
optional<congr_lemma> mk_hcongr_lemma(expr const & fn, unsigned num_args);
optional<congr_lemma> mk_rel_iff_congr(expr const & fn); optional<congr_lemma> mk_rel_iff_congr(expr const & fn);
optional<congr_lemma> mk_rel_eq_congr(expr const & fn); optional<congr_lemma> mk_rel_eq_congr(expr const & fn);

View file

@ -18,7 +18,7 @@ Author: Leonardo de Moura
#include "library/blast/options.h" #include "library/blast/options.h"
#ifndef LEAN_DEFAULT_BLAST_CC_HEQ #ifndef LEAN_DEFAULT_BLAST_CC_HEQ
#define LEAN_DEFAULT_BLAST_CC_HEQ true #define LEAN_DEFAULT_BLAST_CC_HEQ false
#endif #endif
namespace lean { namespace lean {
@ -51,9 +51,25 @@ struct congr_lemma_key_eq_fn {
} }
}; };
LEAN_THREAD_VALUE(bool, g_heq_based, false);
static list<optional<name>> rel_names_from_arg_kinds(list<congr_arg_kind> const & kinds, name const & R) { static list<optional<name>> rel_names_from_arg_kinds(list<congr_arg_kind> const & kinds, name const & R) {
return map2<optional<name>>(kinds, [&](congr_arg_kind k) { return map2<optional<name>>(kinds, [&](congr_arg_kind k) {
return k == congr_arg_kind::Eq ? optional<name>(R) : optional<name>(); switch (k) {
case congr_arg_kind::Eq:
return optional<name>(R);
case congr_arg_kind::HEq:
if (g_heq_based && (R == get_eq_name() || R == get_heq_name())) {
/* Remark: we store equality and heterogeneous equality in the same class. */
return optional<name>(get_eq_name());
} else {
return optional<name>();
}
case congr_arg_kind::Fixed: case congr_arg_kind::Cast:
case congr_arg_kind::FixedNoParam:
return optional<name>();
}
lean_unreachable();
}); });
} }
@ -62,23 +78,26 @@ ext_congr_lemma::ext_congr_lemma(congr_lemma const & H):
m_congr_lemma(H), m_congr_lemma(H),
m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())), m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())),
m_lift_needed(false), m_lift_needed(false),
m_fixed_fun(true) {} m_fixed_fun(true),
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed): m_heq_based(false) {}
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_based):
m_R(R), m_R(R),
m_congr_lemma(H), m_congr_lemma(H),
m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())), m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())),
m_lift_needed(lift_needed), m_lift_needed(lift_needed),
m_fixed_fun(true) {} m_fixed_fun(true),
m_heq_based(heq_based) {}
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed): ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed):
m_R(R), m_R(R),
m_congr_lemma(H), m_congr_lemma(H),
m_rel_names(rel_names), m_rel_names(rel_names),
m_lift_needed(lift_needed), m_lift_needed(lift_needed),
m_fixed_fun(true) {} m_fixed_fun(true),
m_heq_based(false) {}
/* We use the following cache for user-defined lemmas and automatically generated ones. */ /* We use the following cache for user-defined lemmas and automatically generated ones. */
typedef std::unordered_map<congr_lemma_key, optional<ext_congr_lemma>, congr_lemma_key_hash_fn, congr_lemma_key_eq_fn> congr_cache; typedef std::unordered_map<congr_lemma_key, optional<ext_congr_lemma>, congr_lemma_key_hash_fn, congr_lemma_key_eq_fn> congr_cache;
typedef std::tuple<name, expr, expr, expr> cc_todo_entry; typedef std::tuple<name, expr, expr, expr, bool> cc_todo_entry;
static expr * g_congr_mark = nullptr; // dummy congruence proof, it is just a placeholder. static expr * g_congr_mark = nullptr; // dummy congruence proof, it is just a placeholder.
static expr * g_iff_true_mark = nullptr; // dummy iff_true proof, it is just a placeholder. static expr * g_iff_true_mark = nullptr; // dummy iff_true proof, it is just a placeholder.
@ -96,13 +115,14 @@ static void clear_todo() {
get_todo().clear(); get_todo().clear();
} }
static void push_todo(name const & R, expr const & lhs, expr const & rhs, expr const & H) { static void push_todo(name const & R, expr const & lhs, expr const & rhs, expr const & H, bool heq_proof) {
get_todo().emplace_back(R, lhs, rhs, H); get_todo().emplace_back(R, lhs, rhs, H, heq_proof);
} }
scope_congruence_closure::scope_congruence_closure(): scope_congruence_closure::scope_congruence_closure():
m_old_cache(g_congr_cache) { m_old_cache(g_congr_cache) {
g_congr_cache = new congr_cache(); g_congr_cache = new congr_cache();
g_heq_based = is_standard(env()) && get_blast_cc_heq(ios().get_options());
} }
scope_congruence_closure::~scope_congruence_closure() { scope_congruence_closure::~scope_congruence_closure() {
@ -136,6 +156,7 @@ void congruence_closure::mk_entry_core(name const & R, expr const & e, bool to_p
n.m_interpreted = interpreted; n.m_interpreted = interpreted;
n.m_constructor = constructor; n.m_constructor = constructor;
n.m_to_propagate = to_propagate; n.m_to_propagate = to_propagate;
n.m_heq_proofs = false;
n.m_mt = m_gmt; n.m_mt = m_gmt;
m_entries.insert(eqc_key(R, e), n); m_entries.insert(eqc_key(R, e), n);
if (R != get_eq_name()) { if (R != get_eq_name()) {
@ -146,7 +167,8 @@ void congruence_closure::mk_entry_core(name const & R, expr const & e, bool to_p
expr it = n->m_next; expr it = n->m_next;
while (it != e) { while (it != e) {
if (m_entries.find(eqc_key(R, it))) { if (m_entries.find(eqc_key(R, it))) {
push_todo(R, e, it, *g_lift_mark); bool heq_proof = false;
push_todo(R, e, it, *g_lift_mark, heq_proof);
break; break;
} }
auto it_n = m_entries.find(eqc_key(get_eq_name(), it)); auto it_n = m_entries.find(eqc_key(get_eq_name(), it));
@ -332,7 +354,8 @@ static optional<ext_congr_lemma> mk_relation_congr_lemma(name const & R, expr co
if (R == get_iff_name()) { if (R == get_iff_name()) {
if (optional<congr_lemma> cgr = mk_rel_iff_congr(fn)) { if (optional<congr_lemma> cgr = mk_rel_iff_congr(fn)) {
auto child_rel_names = rel_names_from_arg_kinds(cgr->get_arg_kinds(), const_name(fn)); auto child_rel_names = rel_names_from_arg_kinds(cgr->get_arg_kinds(), const_name(fn));
return optional<ext_congr_lemma>(R, *cgr, child_rel_names, false); bool lift_needed = false;
return optional<ext_congr_lemma>(R, *cgr, child_rel_names, lift_needed);
} }
} }
if (optional<congr_lemma> cgr = mk_rel_eq_congr(fn)) { if (optional<congr_lemma> cgr = mk_rel_eq_congr(fn)) {
@ -352,13 +375,40 @@ static optional<ext_congr_lemma> mk_ext_specialized_congr_lemma(name const & R,
if (!eq_congr) if (!eq_congr)
return optional<ext_congr_lemma>(); return optional<ext_congr_lemma>();
ext_congr_lemma res1(*eq_congr); ext_congr_lemma res1(*eq_congr);
// If all arguments are Eq kind, then we can use generic congr axiom and consider equality for the function. /* If all arguments are Eq kind, then we can use generic congr axiom and consider equality for the function. */
if (eq_congr->all_eq_kind()) if (eq_congr->all_eq_kind())
res1.m_fixed_fun = false; res1.m_fixed_fun = false;
if (R == get_eq_name()) if (R == get_eq_name())
return optional<ext_congr_lemma>(res1); return optional<ext_congr_lemma>(res1);
bool lift_needed = true; bool lift_needed = true;
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed); bool heq_based = false;
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_based);
}
/* Automatically generated congruence lemma based on heterogeneous equality. */
static optional<ext_congr_lemma> mk_hcongr_lemma(name const & R, expr const & fn, unsigned nargs) {
optional<congr_lemma> eq_congr = mk_hcongr_lemma(fn, nargs);
if (!eq_congr)
return optional<ext_congr_lemma>();
ext_congr_lemma res1(*eq_congr);
/* If all arguments are Eq kind, then we can use generic congr axiom and consider equality for the function. */
if (eq_congr->all_eq_kind())
res1.m_fixed_fun = false;
if (R == get_eq_name() || R == get_heq_name())
return optional<ext_congr_lemma>(res1);
/* If R is not equality (=) nor heterogeneous equality (==),
we try to lift, but we can only lift if the congruence lemma produces an equality. */
expr type = eq_congr->get_type();
while (is_pi(type)) type = binding_body(type);
lean_assert(is_eq(type) || is_heq(type));
if (is_heq(type)) {
/* We cannot lift heterogeneous equality. */
return optional<ext_congr_lemma>();
} else {
bool heq_based = !eq_congr->all_eq_kind() || is_heq(type);
bool lift_needed = true;
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_based);
}
} }
optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e) { optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e) {
@ -369,6 +419,25 @@ optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e) {
auto it1 = g_congr_cache->find(key1); auto it1 = g_congr_cache->find(key1);
if (it1 != g_congr_cache->end()) if (it1 != g_congr_cache->end())
return it1->second; return it1->second;
if (g_heq_based) {
/* Check if there is user defined lemma for (R, fn, nargs).
Remark: specialization prefix is irrelevan for used defined congruence lemmas. */
auto lemma = mk_ext_user_congr_lemma(R, fn, nargs);
/* Try automatically generated lemma for equivalence relation over iff/eq */
if (!lemma) lemma = mk_relation_congr_lemma(R, fn, nargs);
/* Try automatically generated congruence lemma with support for heterogeneous equality. */
if (!lemma) lemma = mk_hcongr_lemma(R, fn, nargs);
if (lemma) {
/* succeeded */
g_congr_cache->insert(mk_pair(key1, lemma));
return lemma;
}
} else {
lean_assert(!g_heq_based);
/* When heterogeneous equality support is disabled, we use specialization prefix +
congruence lemmas that take care of subsingletons */
/* Check if (g := fn+specialization prefix) is in the cache */ /* Check if (g := fn+specialization prefix) is in the cache */
unsigned prefix_sz = get_specialization_prefix_size(fn, nargs); unsigned prefix_sz = get_specialization_prefix_size(fn, nargs);
unsigned rest_nargs = nargs - prefix_sz; unsigned rest_nargs = nargs - prefix_sz;
@ -397,6 +466,7 @@ optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e) {
g_congr_cache->insert(mk_pair(key2, lemma)); g_congr_cache->insert(mk_pair(key2, lemma));
return lemma; return lemma;
} }
}
/* cache failure */ /* cache failure */
g_congr_cache->insert(mk_pair(key1, optional<ext_congr_lemma>())); g_congr_cache->insert(mk_pair(key1, optional<ext_congr_lemma>()));
return optional<ext_congr_lemma>(); return optional<ext_congr_lemma>();
@ -496,7 +566,6 @@ int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_ke
lean_assert(*it1); lean_assert(*it2); lean_assert(*it1); lean_assert(*it2);
switch (head(*it2)) { switch (head(*it2)) {
case congr_arg_kind::HEq: case congr_arg_kind::HEq:
lean_unreachable();
case congr_arg_kind::Eq: case congr_arg_kind::Eq:
lean_assert(head(*it1)); lean_assert(head(*it1));
r = g_cc->compare_root(*head(*it1), args1[i], args2[i]); r = g_cc->compare_root(*head(*it1), args1[i], args2[i]);
@ -560,7 +629,6 @@ auto congruence_closure::mk_congr_key(ext_congr_lemma const & lemma, expr const
lean_assert(*it1); lean_assert(*it2); lean_assert(*it1); lean_assert(*it2);
switch (head(*it2)) { switch (head(*it2)) {
case congr_arg_kind::HEq: case congr_arg_kind::HEq:
lean_unreachable();
case congr_arg_kind::Eq: case congr_arg_kind::Eq:
lean_assert(head(*it1)); lean_assert(head(*it1));
h = hash(h, get_root(*head(*it1), args[i]).hash()); h = hash(h, get_root(*head(*it1), args[i]).hash());
@ -601,7 +669,8 @@ void congruence_closure::check_iff_true(congr_key const & k) {
if (lhs != rhs) if (lhs != rhs)
return; return;
// Add e <-> true // Add e <-> true
push_todo(get_iff_name(), e, mk_true(), *g_iff_true_mark); bool heq_proof = false;
push_todo(get_iff_name(), e, mk_true(), *g_iff_true_mark, heq_proof);
} }
void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, expr const & e) { void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, expr const & e) {
@ -615,7 +684,8 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp
new_entry.m_cg_root = old_k->m_expr; new_entry.m_cg_root = old_k->m_expr;
m_entries.insert(k, new_entry); m_entries.insert(k, new_entry);
// 2. Put new equivalence in the TODO queue // 2. Put new equivalence in the TODO queue
push_todo(lemma.m_R, e, old_k->m_expr, *g_congr_mark); bool heq_proof = false; // TODO(Leo): fix this
push_todo(lemma.m_R, e, old_k->m_expr, *g_congr_mark, heq_proof);
} else { } else {
m_congruences.insert(k); m_congruences.insert(k);
} }
@ -635,8 +705,11 @@ static bool is_logical_app(expr const & n) {
(const_name(fn) == get_ite_name() && is_prop(n))); (const_name(fn) == get_ite_name() && is_prop(n)));
} }
void congruence_closure::internalize_core(name const & R, expr const & e, bool toplevel, bool to_propagate) { void congruence_closure::internalize_core(name R, expr const & e, bool toplevel, bool to_propagate) {
lean_assert(closed(e)); lean_assert(closed(e));
if (g_heq_based && R == get_heq_name())
R = get_eq_name();
// we allow metavariables after partitions have been frozen // we allow metavariables after partitions have been frozen
if (has_expr_metavar(e) && !m_froze_partitions) if (has_expr_metavar(e) && !m_froze_partitions)
return; return;
@ -804,7 +877,8 @@ void congruence_closure::propagate_no_confusion_eq(expr const & e1, expr const &
/* Remark: If added_prop is not none, then it contains the proposition provided to ::add. /* Remark: If added_prop is not none, then it contains the proposition provided to ::add.
We use it here to avoid an unnecessary propagation back to the current_state. */ We use it here to avoid an unnecessary propagation back to the current_state. */
void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr const & H, optional<expr> const & added_prop) { void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr const & H, optional<expr> const & added_prop,
bool heq_proof) {
auto n1 = m_entries.find(eqc_key(R, e1)); auto n1 = m_entries.find(eqc_key(R, e1));
auto n2 = m_entries.find(eqc_key(R, e2)); auto n2 = m_entries.find(eqc_key(R, e2));
if (!n1 || !n2) if (!n1 || !n2)
@ -888,6 +962,8 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
new_r1.m_next = r2->m_next; new_r1.m_next = r2->m_next;
new_r2.m_next = r1->m_next; new_r2.m_next = r1->m_next;
new_r2.m_size += r1->m_size; new_r2.m_size += r1->m_size;
if (heq_proof)
new_r2.m_heq_proofs = true;
m_entries.insert(eqc_key(R, e1_root), new_r1); m_entries.insert(eqc_key(R, e1_root), new_r1);
m_entries.insert(eqc_key(R, e2_root), new_r2); m_entries.insert(eqc_key(R, e2_root), new_r2);
lean_assert(check_invariant()); lean_assert(check_invariant());
@ -916,7 +992,8 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
bool propagate_back = false; bool propagate_back = false;
mk_entry(R2, e1, propagate_back); mk_entry(R2, e1, propagate_back);
mk_entry(R2, e2, propagate_back); mk_entry(R2, e2, propagate_back);
push_todo(R2, e1, e2, *g_lift_mark); bool heq_proof = false;
push_todo(R2, e1, e2, *g_lift_mark, heq_proof);
} }
} }
} }
@ -957,15 +1034,16 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
void congruence_closure::process_todo(optional<expr> const & added_prop) { void congruence_closure::process_todo(optional<expr> const & added_prop) {
auto & todo = get_todo(); auto & todo = get_todo();
while (!todo.empty()) { while (!todo.empty()) {
name R; expr lhs, rhs, H; name R; expr lhs, rhs, H; bool heq_proof;
std::tie(R, lhs, rhs, H) = todo.back(); std::tie(R, lhs, rhs, H, heq_proof) = todo.back();
todo.pop_back(); todo.pop_back();
add_eqv_step(R, lhs, rhs, H, added_prop); add_eqv_step(R, lhs, rhs, H, added_prop, heq_proof);
} }
} }
void congruence_closure::add_eqv_core(name const & R, expr const & lhs, expr const & rhs, expr const & H, optional<expr> const & added_prop) { void congruence_closure::add_eqv_core(name const & R, expr const & lhs, expr const & rhs, expr const & H,
push_todo(R, lhs, rhs, H); optional<expr> const & added_prop, bool heq_proof) {
push_todo(R, lhs, rhs, H, heq_proof);
process_todo(added_prop); process_todo(added_prop);
} }
@ -975,9 +1053,15 @@ void congruence_closure::add_eqv(name const & R, expr const & lhs, expr const &
flet<congruence_closure *> set_cc(g_cc, this); flet<congruence_closure *> set_cc(g_cc, this);
clear_todo(); clear_todo();
bool toplevel = false; bool to_propagate = false; bool toplevel = false; bool to_propagate = false;
internalize_core(R, lhs, toplevel, to_propagate); bool heq_proof = false;
internalize_core(R, rhs, toplevel, to_propagate); name _R = R;
add_eqv_core(R, lhs, rhs, H, none_expr()); if (g_heq_based && R == get_heq_name()) {
heq_proof = true;
_R = get_eq_name();
}
internalize_core(_R, lhs, toplevel, to_propagate);
internalize_core(_R, rhs, toplevel, to_propagate);
add_eqv_core(_R, lhs, rhs, H, none_expr(), heq_proof);
} }
void congruence_closure::add(hypothesis_idx hidx) { void congruence_closure::add(hypothesis_idx hidx) {
@ -1013,30 +1097,53 @@ void congruence_closure::add(expr const & type, expr const & proof) {
if (is_equivalence_relation_app(p, R, lhs, rhs)) { if (is_equivalence_relation_app(p, R, lhs, rhs)) {
if (is_neg) { if (is_neg) {
bool toplevel = true; bool to_propagate = false; bool toplevel = true; bool to_propagate = false;
bool heq_proof = false;
internalize_core(get_iff_name(), p, toplevel, to_propagate); internalize_core(get_iff_name(), p, toplevel, to_propagate);
add_eqv_core(get_iff_name(), p, mk_false(), mk_iff_false_intro(proof), some_expr(type)); add_eqv_core(get_iff_name(), p, mk_false(), mk_iff_false_intro(proof), some_expr(type), heq_proof);
} else { } else {
bool toplevel = false; bool to_propagate = false; bool toplevel = false; bool to_propagate = false;
bool heq_proof = false;
if (g_heq_based && R == get_heq_name()) {
/* When heterogeneous equality support is enabled, we
store equality and heterogeneous equality are stored in the same equivalence classes. */
R = get_eq_name();
heq_proof = true;
}
internalize_core(R, lhs, toplevel, to_propagate); internalize_core(R, lhs, toplevel, to_propagate);
internalize_core(R, rhs, toplevel, to_propagate); internalize_core(R, rhs, toplevel, to_propagate);
add_eqv_core(R, lhs, rhs, proof, some_expr(type)); add_eqv_core(R, lhs, rhs, proof, some_expr(type), heq_proof);
} }
} else if (is_prop(p)) { } else if (is_prop(p)) {
bool toplevel = true; bool to_propagate = false; bool toplevel = true; bool to_propagate = false;
bool heq_proof = false;
internalize_core(get_iff_name(), p, toplevel, to_propagate); internalize_core(get_iff_name(), p, toplevel, to_propagate);
if (is_neg) { if (is_neg) {
add_eqv_core(get_iff_name(), p, mk_false(), mk_iff_false_intro(proof), some_expr(type)); add_eqv_core(get_iff_name(), p, mk_false(), mk_iff_false_intro(proof), some_expr(type), heq_proof);
} else { } else {
add_eqv_core(get_iff_name(), p, mk_true(), mk_iff_true_intro(proof), some_expr(type)); add_eqv_core(get_iff_name(), p, mk_true(), mk_iff_true_intro(proof), some_expr(type), heq_proof);
} }
} }
} }
bool congruence_closure::has_heq_proofs(expr const & root) const {
lean_assert(m_entries.find(eqc_key(get_eq_name(), root)));
lean_assert(m_entries.find(eqc_key(get_eq_name(), root))->m_root == root);
return m_entries.find(eqc_key(get_eq_name(), root))->m_heq_proofs;
}
bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2) const { bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2) const {
auto n1 = m_entries.find(eqc_key(R, e1)); name R_norm = R;
if (g_heq_based && R == get_heq_name()) {
R_norm = get_eq_name();
}
auto n1 = m_entries.find(eqc_key(R_norm, e1));
if (!n1) return false; if (!n1) return false;
auto n2 = m_entries.find(eqc_key(R, e2)); auto n2 = m_entries.find(eqc_key(R_norm, e2));
if (!n2) return false; if (!n2) return false;
/* Remark: this method assumes that is_eqv is invoked with type correct parameters.
An eq class may contain equality and heterogeneous equality proofs when g_heq_based
is enabled. When this happens, the answer is correct only if e1 and e2 have the same type.
*/
return n1->m_root == n2->m_root; return n1->m_root == n2->m_root;
} }
@ -1168,11 +1275,20 @@ expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const &
} }
} }
static expr flip_proof(name const & R, expr const & H, bool flipped) { static expr flip_proof(name const & R, expr const & H, bool flipped, bool has_heq_proofs) {
if (!flipped || H == *g_congr_mark || H == *g_iff_true_mark || H == *g_lift_mark) { if (H == *g_congr_mark || H == *g_iff_true_mark || H == *g_lift_mark) {
return H; return H;
} else { } else {
return get_app_builder().mk_symm(R, H); auto & b = get_app_builder();
expr new_H = H;
if (has_heq_proofs && is_eq(relaxed_whnf(infer_type(new_H)))) {
new_H = b.mk_heq_of_eq(new_H);
}
if (!flipped) {
return new_H;
} else {
return b.mk_symm(R, new_H);
}
} }
} }
@ -1182,26 +1298,33 @@ static expr mk_trans(name const & R, optional<expr> const & H1, expr const & H2)
optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1, expr const & e2) const { optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1, expr const & e2) const {
app_builder & b = get_app_builder(); app_builder & b = get_app_builder();
name R_key = R; // We use R_key to access the equivalence class data
if (g_heq_based && R == get_heq_name()) {
R_key = get_eq_name();
}
if (has_expr_metavar(e1) || has_expr_metavar(e2)) return none_expr(); if (has_expr_metavar(e1) || has_expr_metavar(e2)) return none_expr();
if (is_def_eq(e1, e2)) if (is_def_eq(e1, e2))
return some_expr(b.lift_from_eq(R, b.mk_eq_refl(e1))); return some_expr(b.lift_from_eq(R, b.mk_eq_refl(e1)));
auto n1 = m_entries.find(eqc_key(R, e1)); auto n1 = m_entries.find(eqc_key(R_key, e1));
if (!n1) return none_expr(); if (!n1) return none_expr();
auto n2 = m_entries.find(eqc_key(R, e2)); auto n2 = m_entries.find(eqc_key(R_key, e2));
if (!n2) return none_expr(); if (!n2) return none_expr();
if (n1->m_root != n2->m_root) return none_expr(); if (n1->m_root != n2->m_root) return none_expr();
bool heq_proofs = R_key == get_eq_name() && has_heq_proofs(n1->m_root);
// R_trans is the relation we use to build the transitivity proofs
name R_trans = heq_proofs ? get_heq_name() : R_key;
// 1. Retrieve "path" from e1 to root // 1. Retrieve "path" from e1 to root
buffer<expr> path1, Hs1; buffer<expr> path1, Hs1;
rb_tree<expr, expr_quick_cmp> visited; rb_tree<expr, expr_quick_cmp> visited;
expr it1 = e1; expr it1 = e1;
while (true) { while (true) {
visited.insert(it1); visited.insert(it1);
auto it1_n = m_entries.find(eqc_key(R, it1)); auto it1_n = m_entries.find(eqc_key(R_key, it1));
lean_assert(it1_n); lean_assert(it1_n);
if (!it1_n->m_target) if (!it1_n->m_target)
break; break;
path1.push_back(*it1_n->m_target); path1.push_back(*it1_n->m_target);
Hs1.push_back(flip_proof(R, *it1_n->m_proof, it1_n->m_flipped)); Hs1.push_back(flip_proof(R_trans, *it1_n->m_proof, it1_n->m_flipped, heq_proofs));
it1 = *it1_n->m_target; it1 = *it1_n->m_target;
} }
lean_assert(it1 == n1->m_root); lean_assert(it1 == n1->m_root);
@ -1212,11 +1335,11 @@ optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1
while (true) { while (true) {
if (visited.contains(it2)) if (visited.contains(it2))
break; // found common break; // found common
auto it2_n = m_entries.find(eqc_key(R, it2)); auto it2_n = m_entries.find(eqc_key(R_key, it2));
lean_assert(it2_n); lean_assert(it2_n);
lean_assert(it2_n->m_target); lean_assert(it2_n->m_target);
path2.push_back(it2); path2.push_back(it2);
Hs2.push_back(flip_proof(R, *it2_n->m_proof, !it2_n->m_flipped)); Hs2.push_back(flip_proof(R_trans, *it2_n->m_proof, !it2_n->m_flipped, heq_proofs));
it2 = *it2_n->m_target; it2 = *it2_n->m_target;
} }
// it2 is the common element... // it2 is the common element...
@ -1238,16 +1361,20 @@ optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1
optional<expr> pr; optional<expr> pr;
expr lhs = e1; expr lhs = e1;
for (unsigned i = 0; i < path1.size(); i++) { for (unsigned i = 0; i < path1.size(); i++) {
pr = mk_trans(R, pr, mk_proof(R, lhs, path1[i], Hs1[i])); pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path1[i], Hs1[i]));
lhs = path1[i]; lhs = path1[i];
} }
unsigned i = Hs2.size(); unsigned i = Hs2.size();
while (i > 0) { while (i > 0) {
--i; --i;
pr = mk_trans(R, pr, mk_proof(R, lhs, path2[i], Hs2[i])); pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path2[i], Hs2[i]));
lhs = path2[i]; lhs = path2[i];
} }
lean_assert(pr); lean_assert(pr);
if (heq_proofs && R == get_eq_name())
pr = b.mk_eq_of_heq(*pr);
else if (!heq_proofs && R == get_heq_name())
pr = b.mk_heq_of_eq(*pr);
return pr; return pr;
} }

View file

@ -48,6 +48,10 @@ class congruence_closure {
unsigned m_to_propagate:1; // must be propagated back to state when in equivalence class containing true/false unsigned m_to_propagate:1; // must be propagated back to state when in equivalence class containing true/false
unsigned m_interpreted:1; // true if the node should be viewed as an abstract value unsigned m_interpreted:1; // true if the node should be viewed as an abstract value
unsigned m_constructor:1; // true if head symbol is a constructor unsigned m_constructor:1; // true if head symbol is a constructor
/* m_heq_proofs == true iff some proofs in the equivalence class are based on heterogeneous equality.
This flag is only used when option blast.cc.heq is set to true.
Moreover, we represent equality and heterogeneous equality in a single equivalence class. */
unsigned m_heq_proofs:1;
unsigned m_size; // number of elements in the equivalence class, it is meaningless if 'e' != m_root unsigned m_size; // number of elements in the equivalence class, it is meaningless if 'e' != m_root
unsigned m_mt; unsigned m_mt;
// The field m_mt is used to implement the mod-time optimization introduce by the Simplify theorem prover. // The field m_mt is used to implement the mod-time optimization introduce by the Simplify theorem prover.
@ -117,7 +121,7 @@ class congruence_closure {
void update_non_eq_relations(name const & R); void update_non_eq_relations(name const & R);
void register_to_propagate(expr const & e); void register_to_propagate(expr const & e);
void internalize_core(name const & R, expr const & e, bool toplevel, bool to_propagate); void internalize_core(name R, expr const & e, bool toplevel, bool to_propagate);
void process_todo(optional<expr> const & added_prop); void process_todo(optional<expr> const & added_prop);
int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const; int compare_symm(name const & R, expr lhs1, expr rhs1, expr lhs2, expr rhs2) const;
@ -138,14 +142,16 @@ class congruence_closure {
void update_mt(name const & R, expr const & e); void update_mt(name const & R, expr const & e);
expr mk_iff_false_intro(expr const & proof); expr mk_iff_false_intro(expr const & proof);
expr mk_iff_true_intro(expr const & proof); expr mk_iff_true_intro(expr const & proof);
void add_eqv_step(name const & R, expr e1, expr e2, expr const & H, optional<expr> const & added_prop); void add_eqv_step(name const & R, expr e1, expr e2, expr const & H, optional<expr> const & added_prop, bool heq_proof);
void add_eqv_core(name const & R, expr const & lhs, expr const & rhs, expr const & H, optional<expr> const & added_prop); void add_eqv_core(name const & R, expr const & lhs, expr const & rhs, expr const & H, optional<expr> const & added_prop, bool heq_proof);
void propagate_no_confusion_eq(expr const & e1, expr const & e2); void propagate_no_confusion_eq(expr const & e1, expr const & e2);
expr mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs) const; expr mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs) const;
expr mk_congr_proof(name const & R, expr const & lhs, expr const & rhs) const; expr mk_congr_proof(name const & R, expr const & lhs, expr const & rhs) const;
expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const; expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const;
bool has_heq_proofs(expr const & root) const;
void trace_eqc(name const & R, expr const & e) const; void trace_eqc(name const & R, expr const & e) const;
public: public:
void initialize(); void initialize();
@ -255,8 +261,10 @@ struct ext_congr_lemma {
/* If m_fixed_fun is false, then we build equivalences for functions, and use generic congr lemma, and ignore m_congr_lemma. /* If m_fixed_fun is false, then we build equivalences for functions, and use generic congr lemma, and ignore m_congr_lemma.
That is, even the function can be treated as an Eq argument. */ That is, even the function can be treated as an Eq argument. */
unsigned m_fixed_fun:1; unsigned m_fixed_fun:1;
/* If m_uses_heq is true, then lemma is based on heterogeneous equality. */
unsigned m_heq_based:1;
ext_congr_lemma(congr_lemma const & H); ext_congr_lemma(congr_lemma const & H);
ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed); ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_based);
ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed); ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed);
name const & get_relation() const { return m_R; } name const & get_relation() const { return m_R; }

View file

@ -0,0 +1,32 @@
set_option blast.strategy "cc"
set_option blast.cc.heq true -- make sure heterogeneous congruence lemmas are enabled
example (a b c : Prop) : a = b → b = c → (a ↔ c) :=
by blast
example (a b c : Prop) : a = b → b == c → (a ↔ c) :=
by blast
example (a b c : nat) : a == b → b = c → a == c :=
by blast
example (a b c : nat) : a == b → b = c → a = c :=
by blast
example (a b c d : nat) : a == b → b == c → c == d → a = d :=
by blast
example (a b c d : nat) : a == b → b = c → c == d → a = d :=
by blast
example (a b c : Prop) : a = b → b = c → (a ↔ c) :=
by blast
example (a b c : Prop) : a == b → b = c → (a ↔ c) :=
by blast
example (a b c d : Prop) : a == b → b == c → c == d → (a ↔ d) :=
by blast
definition foo (a b c d : Prop) : a == b → b = c → c == d → (a ↔ d) :=
by blast

View file

@ -0,0 +1,11 @@
set_option blast.strategy "cc"
set_option blast.cc.heq true -- make sure heterogeneous congruence lemmas are enabled
example (a b c : nat) (f : nat → nat) : a == b → b = c → f a == f c :=
by blast
example (a b c : nat) (f : nat → nat) : a == b → b = c → f a = f c :=
by blast
example (a b c d : nat) (f : nat → nat) : a == b → b = c → c == f d → f a = f (f d) :=
by blast