feat(library/blast/congruence_closure): lift equalities

This commit is contained in:
Leonardo de Moura 2015-11-20 13:46:29 -08:00
parent 28970ef717
commit 8f368cebbf
4 changed files with 127 additions and 47 deletions

View file

@ -71,9 +71,28 @@ struct ext_congr_lemma {
/* We use the following cache for user-defined lemmas and automatically generated ones. */
typedef std::unordered_map<congr_lemma_key, optional<ext_congr_lemma>, congr_lemma_key_hash_fn, congr_lemma_key_eq_fn> congr_cache;
typedef std::tuple<name, expr, expr, expr> cc_todo_entry;
static expr * g_congr_mark = nullptr; // dummy congruence proof, it is just a placeholder.
static expr * g_iff_true_mark = nullptr; // dummy iff_true proof, it is just a placeholder.
static expr * g_lift_mark = nullptr; // dummy lift eq proof, it is just a placeholder.
/* Small hack for not storing a pointer to the congruence_closure object
at congruence_closure::congr_key_cmp */
LEAN_THREAD_PTR(congruence_closure, g_cc);
LEAN_THREAD_VALUE(congr_cache *, g_congr_cache, nullptr);
MK_THREAD_LOCAL_GET_DEF(std::vector<cc_todo_entry>, get_todo);
static void clear_todo() {
get_todo().clear();
}
static void push_todo(name const & R, expr const & lhs, expr const & rhs, expr const & H) {
get_todo().emplace_back(R, lhs, rhs, H);
}
scope_congruence_closure::scope_congruence_closure():
m_old_cache(g_congr_cache) {
g_congr_cache = new congr_cache();
@ -85,11 +104,11 @@ scope_congruence_closure::~scope_congruence_closure() {
}
void congruence_closure::initialize() {
mk_entry_for(get_iff_name(), mk_true());
mk_entry_for(get_iff_name(), mk_false());
mk_entry_core(get_iff_name(), mk_true());
mk_entry_core(get_iff_name(), mk_false());
}
void congruence_closure::mk_entry_for(name const & R, expr const & e) {
void congruence_closure::mk_entry_core(name const & R, expr const & e) {
lean_assert(!m_entries.find(eqc_key(R, e)));
entry n;
n.m_next = e;
@ -98,6 +117,29 @@ void congruence_closure::mk_entry_for(name const & R, expr const & e) {
n.m_size = 1;
n.m_flipped = false;
m_entries.insert(eqc_key(R, e), n);
if (R != get_eq_name()) {
// lift equalities to R
auto n = m_entries.find(eqc_key(get_eq_name(), e));
if (n) {
// e has an eq equivalence class
expr it = n->m_next;
while (it != e) {
if (m_entries.find(eqc_key(R, it))) {
push_todo(R, e, it, *g_lift_mark);
break;
}
auto it_n = m_entries.find(eqc_key(get_eq_name(), it));
lean_assert(it_n);
it = it_n->m_next;
}
}
}
}
void congruence_closure::mk_entry(name const & R, expr const & e) {
if (m_entries.find(eqc_key(R, e)))
return;
mk_entry_core(R, e);
}
static optional<ext_congr_lemma> mk_ext_congr_lemma_core(name const & R, expr const & fn, unsigned nargs) {
@ -153,6 +195,13 @@ static optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const &
return r;
}
void congruence_closure::update_non_eq_relations(name const & R) {
if (R == get_eq_name())
return;
if (std::find(m_non_eq_relations.begin(), m_non_eq_relations.end(), R) == m_non_eq_relations.end())
m_non_eq_relations = cons(R, m_non_eq_relations);
}
void congruence_closure::add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child) {
child_key k(Rc, child);
parent_occ_set ps;
@ -162,10 +211,6 @@ void congruence_closure::add_occurrence(name const & Rp, expr const & parent, na
m_parents.insert(k, ps);
}
/* Small hack for not storing a pointer to the congruence_closure object
at congruence_closure::congr_key_cmp */
LEAN_THREAD_PTR(congruence_closure, g_cc);
/* Auxiliary function for comparing (lhs1 ~ rhs1) and (lhs2 ~ rhs2),
when ~ is symmetric.
It returns 0 (equal) for (a ~ b) (b ~ a) */
@ -323,17 +368,6 @@ auto congruence_closure::mk_congr_key(ext_congr_lemma const & lemma, expr const
return k;
}
static expr * g_congr_mark = nullptr; // dummy congruence proof, it is just a placeholder.
static expr * g_iff_true_mark = nullptr; // dummy iff_true proof, it is just a placeholder.
typedef std::tuple<name, expr, expr, expr> cc_todo_entry;
MK_THREAD_LOCAL_GET_DEF(std::vector<cc_todo_entry>, get_todo);
static void clear_todo() {
get_todo().clear();
}
void congruence_closure::check_iff_true(congr_key const & k) {
expr const & e = k.m_expr;
name R; expr lhs, rhs;
@ -353,7 +387,7 @@ void congruence_closure::check_iff_true(congr_key const & k) {
if (lhs != rhs)
return;
// Add e <-> true
get_todo().emplace_back(get_iff_name(), e, mk_true(), *g_iff_true_mark);
push_todo(get_iff_name(), e, mk_true(), *g_iff_true_mark);
}
void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, expr const & e) {
@ -367,7 +401,7 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp
new_entry.m_cg_root = old_k->m_expr;
m_entries.insert(k, new_entry);
// 2. Put new equivalence in the TODO queue
get_todo().emplace_back(lemma.m_R, e, old_k->m_expr, *g_congr_mark);
push_todo(lemma.m_R, e, old_k->m_expr, *g_congr_mark);
} else {
m_congruences.insert(k);
}
@ -380,6 +414,7 @@ void congruence_closure::internalize_core(name const & R, expr const & e) {
return;
if (m_entries.find(eqc_key(R, e)))
return; // e has already been internalized
update_non_eq_relations(R);
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Meta:
lean_unreachable();
@ -387,12 +422,12 @@ void congruence_closure::internalize_core(name const & R, expr const & e) {
return;
case expr_kind::Constant: case expr_kind::Local:
case expr_kind::Lambda:
mk_entry_for(R, e);
mk_entry_core(R, e);
return;
case expr_kind::Macro:
for (unsigned i = 0; i < macro_num_args(e); i++)
internalize_core(R, macro_arg(e, i));
mk_entry_for(R, e);
mk_entry_core(R, e);
break;
case expr_kind::Pi:
if (is_arrow(e) && is_prop(binding_domain(e)) && is_prop(binding_body(e))) {
@ -400,11 +435,11 @@ void congruence_closure::internalize_core(name const & R, expr const & e) {
internalize_core(R, binding_body(e));
}
if (is_prop(e)) {
mk_entry_for(R, e);
mk_entry_core(R, e);
}
return;
case expr_kind::App: {
mk_entry_for(R, e);
mk_entry_core(R, e);
buffer<expr> args;
expr const & fn = get_app_args(e, args);
if (auto lemma = mk_ext_congr_lemma(R, fn, args.size())) {
@ -485,6 +520,8 @@ void congruence_closure::remove_parents(name const & R, expr const & e) {
void congruence_closure::reinsert_parents(name const & R, expr const & e) {
auto ps = m_parents.find(child_key(R, e));
if (!ps) return;
// TODO(Leo): consider the following optimization:
// 1- remove from ps any parent that is not a congruence root anymore
ps->for_each([&](parent_occ const & p) {
expr const & fn = get_app_fn(p.m_expr);
unsigned nargs = get_app_num_args(p.m_expr);
@ -564,7 +601,7 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
// copy e1_root parents to e2_root
child_key k1(R, e1_root);
auto ps1 = m_parents.find(k1);
if (!ps1) return; // e1_root doesn't have parents
if (ps1) {
parent_occ_set ps2;
child_key k2(R, e2_root);
if (auto it = m_parents.find(k2))
@ -574,6 +611,19 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con
});
m_parents.erase(k1);
m_parents.insert(k2, ps2);
}
// lift equivalence
if (R == get_eq_name()) {
for (name const & R2 : m_non_eq_relations) {
if (m_entries.find(eqc_key(R2, e1)) ||
m_entries.find(eqc_key(R2, e2))) {
mk_entry(R2, e1);
mk_entry(R2, e2);
push_todo(R2, e1, e2, *g_lift_mark);
}
}
}
}
void congruence_closure::process_todo() {
@ -586,8 +636,8 @@ void congruence_closure::process_todo() {
}
}
void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) {
get_todo().emplace_back(_R, _lhs, _rhs, _H);
void congruence_closure::add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & H) {
push_todo(R, lhs, rhs, H);
process_todo();
}
@ -707,17 +757,16 @@ expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const &
} else if (H == *g_iff_true_mark) {
// TODO(Leo):
lean_unreachable();
} else if (H == *g_lift_mark) {
expr H1 = *get_eqv_proof(get_eq_name(), lhs, rhs);
return get_app_builder().lift_from_eq(R, H1);
} else {
return H;
}
}
static expr flip_proof(name const & R, expr const & H, bool flipped) {
if (!flipped) {
return H;
} else if (H == *g_congr_mark) {
return H;
} else if (H == *g_iff_true_mark) {
if (!flipped || H == *g_congr_mark || H == *g_iff_true_mark || H == *g_lift_mark) {
return H;
} else {
return get_app_builder().mk_symm(R, H);
@ -756,7 +805,6 @@ optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1
// 2. The path from e2 to root must have at least one element c in visited
// Retrieve "path" from e2 to c
buffer<expr> path2, Hs2;
path2.push_back(e2);
expr it2 = e2;
while (true) {
if (visited.contains(it2))
@ -772,7 +820,7 @@ optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1
// 3. Shink path1/Hs1 until we find it2 (the common element)
while (true) {
if (path1.empty()) {
lean_assert(it2 == e1 && e1 == n2->m_root);
lean_assert(it2 == e1);
break;
}
if (path1.back() == it2) {
@ -986,7 +1034,6 @@ bool congruence_closure::check_invariant() const {
lean_assert(check_eqc(k.m_R, k.m_expr));
}
});
// TODO(Leo):
return true;
}
@ -994,10 +1041,12 @@ void initialize_congruence_closure() {
name prefix = name::mk_internal_unique_name();
g_congr_mark = new expr(mk_constant(name(prefix, "[congruence]")));
g_iff_true_mark = new expr(mk_constant(name(prefix, "[iff-true]")));
g_lift_mark = new expr(mk_constant(name(prefix, "[lift]")));
}
void finalize_congruence_closure() {
delete g_congr_mark;
delete g_iff_true_mark;
delete g_lift_mark;
}
}}

View file

@ -98,6 +98,9 @@ class congruence_closure {
entries m_entries;
parents m_parents;
congruences m_congruences;
list<name> m_non_eq_relations;
void update_non_eq_relations(name const & R);
void internalize_core(name const & R, expr const & e);
void process_todo();
@ -108,7 +111,8 @@ class congruence_closure {
congr_key mk_congr_key(ext_congr_lemma const & lemma, expr const & e) const;
void check_iff_true(congr_key const & k);
void mk_entry_for(name const & R, expr const & e);
void mk_entry_core(name const & R, expr const & e);
void mk_entry(name const & R, expr const & e);
void add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child);
void add_congruence_table(ext_congr_lemma const & lemma, expr const & e);
void invert_trans(name const & R, expr const & e, optional<expr> new_target, optional<expr> new_proof);
@ -157,9 +161,9 @@ public:
/** \brief Return the proof of inconsistency */
optional<expr> get_inconsistency_proof() const;
/** \brief Return true iff 'e1' and 'e2' are in the same equivalence class for relation \c rel_name. */
/** \brief Return true iff 'e1' and 'e2' are in the same equivalence class for relation \c R. */
bool is_eqv(name const & R, expr const & e1, expr const & e2) const;
optional<expr> get_eqv_proof(name const & rel_name, expr const & e1, expr const & e2) const;
optional<expr> get_eqv_proof(name const & R, expr const & e1, expr const & e2) const;
/** \brief Return true iff `e1 ~ e2` is in the equivalence class of false for iff. */
bool is_uneqv(name const & R, expr const & e1, expr const & e2) const;

View file

@ -17,3 +17,6 @@ open perm
example (a b c d : list nat) : a ~ b → c ~ b → d ~ c → a ~ d :=
by blast
example (a b c d : list nat) : a ~ b → c ~ b → d = c → a ~ d :=
by blast

View file

@ -0,0 +1,24 @@
set_option blast.init_depth 10
set_option blast.inc_depth 100
set_option blast.subst false
example (a b c d : nat) : a == b → b = c → c == d → a == d :=
by blast
example (a b c d : nat) : a = b → b = c → c == d → a == d :=
by blast
example (a b c d : nat) : a = b → b == c → c == d → a == d :=
by blast
example (a b c d : nat) : a == b → b == c → c = d → a == d :=
by blast
example (a b c d : nat) : a == b → b = c → c = d → a == d :=
by blast
example (a b c d : nat) : a = b → b = c → c = d → a == d :=
by blast
example (a b c d : nat) : a = b → b == c → c = d → a == d :=
by blast