feat(library/blast): add union-find datastructure
This commit is contained in:
parent
9649d540c0
commit
fa70930ef4
4 changed files with 296 additions and 0 deletions
|
@ -340,6 +340,7 @@ add_subdirectory(tests/util/numerics)
|
|||
add_subdirectory(tests/util/interval)
|
||||
add_subdirectory(tests/kernel)
|
||||
add_subdirectory(tests/library)
|
||||
add_subdirectory(tests/library/blast)
|
||||
add_subdirectory(tests/frontends/lean)
|
||||
|
||||
# Include style check
|
||||
|
|
232
src/library/blast/union_find.h
Normal file
232
src/library/blast/union_find.h
Normal file
|
@ -0,0 +1,232 @@
|
|||
/*
|
||||
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include "util/rb_map.h"
|
||||
#include "util/optional.h"
|
||||
|
||||
namespace lean {
|
||||
/** \brief (template for) Union-find datastructure that "explains" implied equalities.
|
||||
We use functional datastructures to be able to have a O(1) copy operation.
|
||||
|
||||
Each join/union is decorated with a justification.
|
||||
|
||||
\c cmp implements a total order on \c node. That is, it provides the method:
|
||||
int operator()(node const & n1, node const & n2) const
|
||||
s.t. the result is negative when n1 < n2, 0 if n1 == n2, and positive if n1 > n2.
|
||||
|
||||
The implementation also provides a method to traverse the elements of an equivalence
|
||||
class. The implementation is based on a datastructure used in the Simplify theorem prover.
|
||||
|
||||
Since it provides extra functionality, it does not implement the O(n*alpha(n)) amortized time
|
||||
per operation algorithm.
|
||||
*/
|
||||
template<typename node, typename jst, typename cmp>
|
||||
class union_find : private cmp {
|
||||
rb_map<node, node, cmp> m_root;
|
||||
rb_map<node, node, cmp> m_next;
|
||||
rb_map<node, unsigned, cmp> m_rank;
|
||||
rb_map<node, pair<node, jst>, cmp> m_jst;
|
||||
|
||||
bool is_equal(node const & n1, node const & n2) const {
|
||||
return cmp::operator()(n1, n2) == 0;
|
||||
}
|
||||
|
||||
unsigned rank(node const & n) const {
|
||||
if (auto r = m_rank.find(n))
|
||||
return *r;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
void set_rank(node const & n, unsigned r) { m_rank.insert(n, r); }
|
||||
|
||||
node const & root(node const & n) const {
|
||||
if (auto r = m_root.find(n))
|
||||
return *r;
|
||||
else
|
||||
return n;
|
||||
}
|
||||
void set_root(node const & n, node const & r) { m_root.insert(n, r); }
|
||||
|
||||
node const & next(node const & n) const {
|
||||
if (auto r = m_next.find(n))
|
||||
return *r;
|
||||
else
|
||||
return n;
|
||||
}
|
||||
void set_next(node const & n, node const & nx) { m_next.insert(n, nx); }
|
||||
void set_justification(node const & n, node const & t, jst const & j) { m_jst.insert(n, mk_pair(t, j)); }
|
||||
|
||||
// for debugging purposes only
|
||||
bool check_inv(node const & n) const {
|
||||
node r = root(n);
|
||||
unsigned sz = size(r);
|
||||
node it = n;
|
||||
do {
|
||||
lean_assert_eq(root(it), r);
|
||||
lean_assert(reaches(it, r));
|
||||
lean_assert(size(it), sz);
|
||||
it = next(it);
|
||||
} while (!is_equal(it, n));
|
||||
return true;
|
||||
}
|
||||
|
||||
void join_core(node const & n1, node r1, node const & n2, node r2, jst const & j) {
|
||||
// r1 will be the root of the resulting equivalence class.
|
||||
DEBUG_CODE(unsigned sz1 = size(n1); unsigned sz2 = size(n2););
|
||||
// Step 1) update m_jst
|
||||
//
|
||||
// Given justification paths
|
||||
// n1 -> ... -> r1
|
||||
// n2 -> ... -> r2
|
||||
// we generate the path
|
||||
// r2 -> ... -> n2 -> n1 -> ... -> r1
|
||||
buffer<pair<node, jst>> trace;
|
||||
node it2 = n2;
|
||||
while (pair<node, jst> const * p = m_jst.find(it2)) {
|
||||
trace.push_back(*p);
|
||||
it2 = p->first;
|
||||
}
|
||||
lean_assert(is_equal(it2, r2));
|
||||
unsigned i = trace.size();
|
||||
while (i > 1) {
|
||||
--i;
|
||||
set_justification(trace[i].first, trace[i-1].first, trace[i].second);
|
||||
}
|
||||
if (i > 0) {
|
||||
set_justification(trace[0].first, n2, trace[0].second);
|
||||
}
|
||||
set_justification(n2, n1, j);
|
||||
|
||||
// Step 2) update m_root of nodes in n2 equivalence class to r1
|
||||
it2 = n2;
|
||||
do {
|
||||
set_root(it2, r1);
|
||||
it2 = next(it2);
|
||||
} while (!is_equal(it2, n2));
|
||||
|
||||
// Step 3) update m_next of r1 and r2
|
||||
node next1 = next(r1);
|
||||
node next2 = next(r2);
|
||||
set_next(r1, next2);
|
||||
set_next(r2, next1);
|
||||
|
||||
lean_assert(check_inv(r1));
|
||||
lean_assert_eq(size(n1), sz1 + sz2);
|
||||
}
|
||||
|
||||
/** \brief Return true if \c s reaches \c r by following m_jst edges */
|
||||
bool reaches(node const & s, node const & r) const {
|
||||
node it = s;
|
||||
while (true) {
|
||||
if (is_equal(it, r))
|
||||
return true;
|
||||
pair<node, jst> const * p = m_jst.find(it);
|
||||
if (p) {
|
||||
it = p->first;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void explain_core(node const & n1, node const & n2, node const & r, buffer<jst> & js) const {
|
||||
lean_assert(is_equal(root(n1), r));
|
||||
lean_assert(is_equal(root(n2), r));
|
||||
node it1 = n1;
|
||||
while (true) {
|
||||
if (reaches(n2, it1)) {
|
||||
// it is the common in the paths n1 -> r and n2 -> r
|
||||
node it2 = n2;
|
||||
unsigned sz1 = js.size();
|
||||
while (true) {
|
||||
if (is_equal(it2, it1)) {
|
||||
std::reverse(js.begin() + sz1, js.end());
|
||||
return;
|
||||
}
|
||||
pair<node, jst> const * p = m_jst.find(it2);
|
||||
lean_assert(p);
|
||||
js.push_back(p->second);
|
||||
it2 = p->first;
|
||||
}
|
||||
} else {
|
||||
pair<node, jst> const * p = m_jst.find(it1);
|
||||
lean_assert(p);
|
||||
js.push_back(p->second);
|
||||
it1 = p->first;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
union_find(cmp const & c = cmp()):cmp(c) {}
|
||||
|
||||
/** \brief Merge the equivalence class of \c n1 with \c n2 using justification \c j. */
|
||||
void join(node const & n1, node const & n2, jst const & j) {
|
||||
node const & r1 = root(n1);
|
||||
node const & r2 = root(n2);
|
||||
if (is_equal(r1, r2))
|
||||
return;
|
||||
unsigned k1 = rank(n1);
|
||||
unsigned k2 = rank(n2);
|
||||
if (k1 > k2) {
|
||||
join_core(n1, r1, n2, r2, j);
|
||||
} else if (k1 == k2) {
|
||||
join_core(n1, r1, n2, r2, j);
|
||||
set_rank(n1, k1+1);
|
||||
} else {
|
||||
join_core(n2, r2, n1, r1, j);
|
||||
}
|
||||
}
|
||||
|
||||
/** \brief Return the size of the equivalence class containing \c n */
|
||||
unsigned size(node const & n) const {
|
||||
unsigned r = 0;
|
||||
node it = n;
|
||||
do {
|
||||
lean_assert(is_eq(it, n));
|
||||
r++;
|
||||
it = next(it);
|
||||
} while (!is_equal(it, n));
|
||||
return r;
|
||||
}
|
||||
|
||||
/** \brief Return the representative for the equivalence class containing \c n. */
|
||||
node rep(node const & n) const { return root(n); }
|
||||
|
||||
/** \brief Return true iff \c n1 and \c n2 are in the same equivalence class. */
|
||||
bool is_eq(node const & n1, node const & n2) const { return is_equal(rep(n1), rep(n2)); }
|
||||
|
||||
/** \brief For each node \c m in the equivalence class of \c n, execute <tt>f(m)</tt> */
|
||||
template<typename F>
|
||||
void for_each(node const & n, F f) const {
|
||||
node it = n;
|
||||
do {
|
||||
lean_assert(is_eq(it, n));
|
||||
f(it);
|
||||
it = next(it);
|
||||
} while (!is_equal(it, n));
|
||||
}
|
||||
|
||||
/** \brief If is_eq(n1, n2), then return true and store the justifications that can be used to produce
|
||||
a transitivity+symmetry proof for n1 = n2 */
|
||||
bool explain(node const & n1, node const & n2, buffer<jst> & js) const {
|
||||
node r1 = root(n1);
|
||||
node r2 = root(n2);
|
||||
if (is_equal(r1, r2)) {
|
||||
if (rank(r1) >= rank(r2)) {
|
||||
explain_core(n1, n2, r1, js);
|
||||
} else {
|
||||
explain_core(n2, n1, r1, js);
|
||||
std::reverse(js.begin(), js.end());
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
3
src/tests/library/blast/CMakeLists.txt
Normal file
3
src/tests/library/blast/CMakeLists.txt
Normal file
|
@ -0,0 +1,3 @@
|
|||
add_executable(union_find union_find.cpp)
|
||||
target_link_libraries(union_find "util" ${EXTRA_LIBS})
|
||||
add_test(union_find "${CMAKE_CURRENT_BINARY_DIR}/union_find")
|
60
src/tests/library/blast/union_find.cpp
Normal file
60
src/tests/library/blast/union_find.cpp
Normal file
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include "util/test.h"
|
||||
#include "util/init_module.h"
|
||||
#include "library/blast/union_find.h"
|
||||
using namespace lean;
|
||||
|
||||
typedef union_find<unsigned, unsigned, unsigned_cmp> uf;
|
||||
|
||||
static void check_explain(uf const & m, unsigned n1, unsigned n2, std::initializer_list<unsigned> const & expected_js) {
|
||||
buffer<unsigned> js1;
|
||||
bool r = m.explain(n1, n2, js1);
|
||||
lean_assert(r);
|
||||
lean_assert(m.rep(n1) == m.rep(n2));
|
||||
std::sort(js1.begin(), js1.end());
|
||||
buffer<unsigned> js2;
|
||||
js2.append(expected_js.size(), expected_js.begin());
|
||||
std::sort(js2.begin(), js2.end());
|
||||
lean_assert(js1.size() == js2.size());
|
||||
for (unsigned i = 0; i < js1.size(); i++) {
|
||||
lean_assert(js1[i] == js2[i]);
|
||||
}
|
||||
}
|
||||
|
||||
static void tst1() {
|
||||
uf m;
|
||||
m.join(1, 2, 0);
|
||||
lean_assert(m.is_eq(1, 1));
|
||||
lean_assert(m.is_eq(1, 2));
|
||||
m.join(1, 3, 1);
|
||||
lean_assert(m.is_eq(2, 3));
|
||||
check_explain(m, 2, 3, {0, 1});
|
||||
check_explain(m, 2, 1, {0});
|
||||
check_explain(m, 1, 3, {1});
|
||||
m.join(3, 4, 2);
|
||||
m.join(5, 1, 3);
|
||||
m.join(6, 2, 4);
|
||||
lean_assert(m.rep(6) == m.rep(4));
|
||||
check_explain(m, 2, 3, {0, 1});
|
||||
check_explain(m, 6, 4, {0, 1, 2, 4});
|
||||
check_explain(m, 5, 6, {0, 3, 4});
|
||||
lean_assert_eq(m.size(1), 6);
|
||||
|
||||
for (unsigned i = 10; i < 30; i++)
|
||||
m.join(i, i+1, i);
|
||||
check_explain(m, 10, 15, {10, 11, 12, 13, 14});
|
||||
lean_assert_eq(m.size(10), 21);
|
||||
}
|
||||
|
||||
int main() {
|
||||
save_stack_info();
|
||||
initialize_util_module();
|
||||
tst1();
|
||||
finalize_util_module();
|
||||
return has_violations() ? 1 : 0;
|
||||
}
|
Loading…
Reference in a new issue