feat(library/blast/congruence_closure): add "add_eqv"

This commit is contained in:
Leonardo de Moura 2015-11-19 15:55:56 -08:00
parent f9ced4b3e1
commit 968b615390
2 changed files with 173 additions and 11 deletions

View file

@ -10,6 +10,7 @@ Author: Leonardo de Moura
#include "library/blast/congruence_closure.h"
#include "library/blast/util.h"
#include "library/blast/blast.h"
#include "library/blast/trace.h"
namespace lean {
namespace blast {
@ -88,7 +89,7 @@ void congruence_closure::mk_entry_for(name const & R, expr const & e) {
n.m_next = e;
n.m_root = e;
n.m_cg_root = e;
n.m_rank = 0;
n.m_size = 1;
m_entries.insert(eqc_key(R, e), n);
}
@ -159,6 +160,8 @@ void congruence_closure::internalize(name const & R, expr const & e) {
lean_assert(closed(e));
if (has_expr_metavar(e))
return;
if (m_entries.find(eqc_key(R, e)))
return; // e has already been internalized
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Meta:
lean_unreachable();
@ -221,14 +224,117 @@ static void clear_todo() {
get_todo().clear();
}
void congruence_closure::add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & pr) {
/*
The fields m_target and m_proof in e's entry are encoding a transitivity proof
Let target(e) and proof(e) denote these fields.
e = target(e) : proof(e)
... = target(target(e)) : proof(target(e))
... ...
= root(e) : ...
The transitivity proof eventually reaches the root of the equivalence class.
This method "inverts" the proof. That is, the m_target goes from root(e) to e after
we execute it.
*/
void congruence_closure::invert_trans(name const & R, expr const & e, optional<expr> new_target, optional<expr> new_proof) {
eqc_key k(R, e);
auto n = m_entries.find(k);
lean_assert(n);
entry new_n = *n;
if (n->m_target)
invert_trans(R, *new_n.m_target, some_expr(e), new_n.m_proof);
new_n.m_target = new_target;
new_n.m_proof = new_proof;
m_entries.insert(k, new_n);
}
void congruence_closure::invert_trans(name const & R, expr const & e) {
invert_trans(R, e, none_expr(), none_expr());
}
void congruence_closure::remove_parents(name const & R, expr const & e) {
std::cout << R << " " << e << "\n";
// TODO(Leo):
}
void congruence_closure::insert_parents(name const & R, expr const & e) {
std::cout << R << " " << e << "\n";
// TODO(Leo):
}
void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr const & H) {
auto n1 = m_entries.find(eqc_key(R, e1));
auto n2 = m_entries.find(eqc_key(R, e2));
if (!n1 || !n2)
return;
if (n1->m_root == n2->m_root)
return; // they are already in the same equivalence class
auto r1 = m_entries.find(eqc_key(R, n1->m_root));
auto r2 = m_entries.find(eqc_key(R, n2->m_root));
lean_assert(r1 && r2);
// We want r2 to be the root of the combined class.
if (r1->m_size > r2->m_size) {
std::swap(e1, e2);
std::swap(n1, n2);
std::swap(r1, r2);
// Remark: we don't apply symmetry eagerly. So, we don't adjust H.
}
expr e1_root = n1->m_root;
expr e2_root = n2->m_root;
entry new_n1 = *n1;
// Following target/proof we have
// e1 -> ... -> r1
// e2 -> ... -> r2
// We want
// r1 -> ... -> e1 -> e2 -> ... -> r2
invert_trans(R, e1);
new_n1.m_target = e2;
new_n1.m_proof = H;
m_entries.insert(eqc_key(R, e1), new_n1);
// The hash code for the parents is going to change
remove_parents(R, e1);
// force all m_root fields in e1 equivalence class to point to e2_root
expr it = e1;
do {
auto it_n = m_entries.find(eqc_key(R, it));
lean_assert(it_n);
entry new_it_n = *it_n;
new_it_n.m_root = e2_root;
m_entries.insert(eqc_key(R, it), new_it_n);
it = new_it_n.m_next;
} while (it != e1);
insert_parents(R, e1);
// update next of e1_root and e2_root, and size of e2_root
r1 = m_entries.find(eqc_key(R, e1_root));
r2 = m_entries.find(eqc_key(R, e2_root));
lean_assert(r1 && r2);
lean_assert(r1->m_root == e2_root);
entry new_r1 = *r1;
entry new_r2 = *r2;
new_r1.m_next = r2->m_next;
new_r2.m_next = r1->m_next;
new_r2.m_size += r1->m_size;
m_entries.insert(eqc_key(R, e1_root), new_r1);
m_entries.insert(eqc_key(R, e2_root), new_r2);
lean_assert(check_invariant());
}
void congruence_closure::add_eqv(name const & _R, expr const & _lhs, expr const & _rhs, expr const & _H) {
auto & todo = get_todo();
todo.emplace_back(R, lhs, rhs, pr);
todo.emplace_back(_R, _lhs, _rhs, _H);
while (!todo.empty()) {
name R; expr lhs, rhs, pr;
std::tie(R, lhs, rhs, pr) = todo.back();
name R; expr lhs, rhs, H;
std::tie(R, lhs, rhs, H) = todo.back();
todo.pop_back();
// TODO(Leo): process
add_eqv_step(R, lhs, rhs, H);
}
}
@ -241,7 +347,7 @@ void congruence_closure::add(hypothesis_idx hidx) {
hypothesis const & h = s.get_hypothesis_decl(hidx);
try {
expr const & type = h.get_type();
expr p;
expr p = type;
bool is_neg = is_not(type, p);
if (is_neg && !is_standard(env()))
return;
@ -385,11 +491,60 @@ expr congruence_closure::get_next(name const & R, expr const & e) const {
}
}
void congruence_closure::display_eqc(name const & R, expr const & e) const {
auto out = diagnostic(env(), ios());
bool first = true;
expr it = e;
out << R << " {";
do {
auto it_n = m_entries.find(eqc_key(R, it));
if (first) first = false; else out << ", ";
out << ppb(it);
it = it_n->m_next;
} while (it != e);
out << "}";
}
void congruence_closure::display() const {
// TODO(Leo):
auto out = diagnostic(env(), ios());
m_entries.for_each([&](eqc_key const & k, entry const & n) {
if (k.m_expr == n.m_root) {
display_eqc(k.m_R, k.m_expr);
out << "\n";
}
});
}
bool congruence_closure::check_eqc(name const & R, expr const & e) const {
expr root = get_root(R, e);
unsigned size = 0;
expr it = e;
do {
auto it_n = m_entries.find(eqc_key(R, it));
lean_assert(it_n);
lean_assert(it_n->m_root == root);
auto it2 = it;
// following m_target fields should lead to root
while (true) {
auto it2_n = m_entries.find(eqc_key(R, it2));
if (!it2_n->m_target)
break;
it2 = *it2_n->m_target;
}
lean_assert(it2 == root);
it = it_n->m_next;
size++;
} while (it != e);
lean_assert(m_entries.find(eqc_key(R, root))->m_size == size);
return true;
}
bool congruence_closure::check_invariant() const {
m_entries.for_each([&](eqc_key const & k, entry const & n) {
if (k.m_expr == n.m_root) {
lean_assert(check_eqc(k.m_R, k.m_expr));
}
});
// TODO(Leo):
return true;
}

View file

@ -43,7 +43,7 @@ class congruence_closure {
// store 'target' at 'm_target', and 'H' at 'm_proof'. Both fields are none if 'e' == m_root
optional<expr> m_target;
optional<expr> m_proof;
unsigned m_rank; // rank of the equivalence class, it is meaningless if 'e' != m_root
unsigned m_size; // number of elements in the equivalence class, it is meaningless if 'e' != m_root
};
/* Key (R, e) for the mapping (R, e) -> entry */
@ -58,7 +58,7 @@ class congruence_closure {
int operator()(eqc_key const & k1, eqc_key const & k2) const {
int r = quick_cmp(k1.m_R, k2.m_R);
if (r != 0) return r;
else return is_lt(k1.m_expr, k2.m_expr, true);
else return expr_quick_cmp()(k1.m_expr, k2.m_expr);
}
};
@ -95,8 +95,14 @@ class congruence_closure {
void mk_entry_for(name const & R, expr const & e);
void add_occurrence(name const & Rp, expr const & parent, name const & Rc, expr const & child);
void add_congruence_table(ext_congr_lemma const & lemma, expr const & e);
void add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & pr);
void invert_trans(name const & R, expr const & e, optional<expr> new_target, optional<expr> new_proof);
void invert_trans(name const & R, expr const & e);
void remove_parents(name const & R, expr const & e);
void insert_parents(name const & R, expr const & e);
void add_eqv_step(name const & R, expr e1, expr e2, expr const & H);
void add_eqv(name const & R, expr const & lhs, expr const & rhs, expr const & H);
void display_eqc(name const & R, expr const & e) const;
public:
/** \brief Register expression \c e in this data-structure.
It creates entries for each sub-expression in \c e.
@ -154,6 +160,7 @@ public:
/** \brief dump for debugging purposes. */
void display() const;
bool check_eqc(name const & R, expr const & e) const;
bool check_invariant() const;
};