feat(util): add red-black trees
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
fdde12e6af
commit
528ea367ad
3 changed files with 502 additions and 0 deletions
|
@ -28,6 +28,9 @@ add_test(thread ${CMAKE_CURRENT_BINARY_DIR}/thread)
|
|||
add_executable(memory memory.cpp)
|
||||
target_link_libraries(memory ${EXTRA_LIBS})
|
||||
add_test(memory ${CMAKE_CURRENT_BINARY_DIR}/memory)
|
||||
add_executable(rb_tree rb_tree.cpp)
|
||||
target_link_libraries(rb_tree ${EXTRA_LIBS})
|
||||
add_test(rb_tree ${CMAKE_CURRENT_BINARY_DIR}/rb_tree)
|
||||
add_executable(splay_tree splay_tree.cpp)
|
||||
target_link_libraries(splay_tree ${EXTRA_LIBS})
|
||||
add_test(splay_tree ${CMAKE_CURRENT_BINARY_DIR}/splay_tree)
|
||||
|
|
190
src/tests/util/rb_tree.cpp
Normal file
190
src/tests/util/rb_tree.cpp
Normal file
|
@ -0,0 +1,190 @@
|
|||
/*
|
||||
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <ctime>
|
||||
#include <unordered_set>
|
||||
#include <sstream>
|
||||
#include "util/test.h"
|
||||
#include "util/buffer.h"
|
||||
#include "util/rb_tree.h"
|
||||
#include "util/timeit.h"
|
||||
using namespace lean;
|
||||
|
||||
struct int_lt { int operator()(int i1, int i2) const { return i1 < i2 ? -1 : (i1 > i2 ? 1 : 0); } };
|
||||
typedef rb_tree<int, int_lt> int_rb_tree;
|
||||
typedef std::unordered_set<int> int_set;
|
||||
|
||||
void print(int_rb_tree const & t) {
|
||||
std::cout << t << "\n";
|
||||
}
|
||||
|
||||
static void tst1() {
|
||||
int_rb_tree s;
|
||||
for (unsigned i = 0; i < 100; i++) {
|
||||
s.insert(i);
|
||||
}
|
||||
std::cout << s << "\n";
|
||||
std::cout << "DEPTH: " << s.get_depth() << "\n";
|
||||
int_rb_tree s2 = s;
|
||||
std::cout << "DEPTH: " << s2.get_depth() << "\n";
|
||||
s2.insert(200);
|
||||
lean_assert_eq(s2.size(), s.size() + 1);
|
||||
for (unsigned i = 0; i < 100; i++) {
|
||||
lean_assert(s.contains(i));
|
||||
lean_assert(s2.contains(i));
|
||||
}
|
||||
lean_assert(!s.contains(200));
|
||||
lean_assert(s2.contains(200));
|
||||
}
|
||||
|
||||
static void tst2() {
|
||||
int_rb_tree s;
|
||||
s.insert(10);
|
||||
s.insert(11);
|
||||
s.insert(9);
|
||||
std::cout << s << "\n";
|
||||
int_rb_tree s2 = s;
|
||||
std::cout << s2 << "\n";
|
||||
s.insert(20);
|
||||
std::cout << s << "\n";
|
||||
s.insert(15);
|
||||
}
|
||||
|
||||
static void tst3() {
|
||||
int_rb_tree s;
|
||||
s.insert(10);
|
||||
s.insert(3);
|
||||
s.insert(20);
|
||||
std::cout << s << "\n";
|
||||
s.insert(40);
|
||||
std::cout << s << "\n";
|
||||
s.insert(5);
|
||||
std::cout << s << "\n";
|
||||
s.insert(11);
|
||||
std::cout << s << "\n";
|
||||
s.insert(20);
|
||||
std::cout << s << "\n";
|
||||
s.insert(30);
|
||||
std::cout << s << "\n";
|
||||
s.insert(25);
|
||||
std::cout << s << "\n";
|
||||
s.insert(15);
|
||||
lean_assert(s.contains(40));
|
||||
lean_assert(s.contains(11));
|
||||
lean_assert(s.contains(20));
|
||||
lean_assert(s.contains(25));
|
||||
lean_assert(s.contains(5));
|
||||
lean_assert(s.contains(10));
|
||||
lean_assert(s.contains(3));
|
||||
lean_assert(s.contains(20));
|
||||
std::cout << s << "\n";
|
||||
int_rb_tree s2(s);
|
||||
std::cout << s2 << "\n";
|
||||
s.insert(34);
|
||||
std::cout << s2 << "\n";
|
||||
std::cout << s << "\n";
|
||||
int const * v = s.find(11);
|
||||
lean_assert(*v == 11);
|
||||
std::cout << s << "\n";
|
||||
lean_assert(!s.empty());
|
||||
s.clear();
|
||||
lean_assert(s.empty());
|
||||
}
|
||||
|
||||
static bool operator==(int_set const & v1, int_rb_tree const & v2) {
|
||||
buffer<int> b;
|
||||
// std::cout << v2 << "\n";
|
||||
// std::for_each(v1.begin(), v1.end(), [](int v) { std::cout << v << " "; }); std::cout << "\n";
|
||||
v2.to_buffer(b);
|
||||
if (v1.size() != b.size())
|
||||
return false;
|
||||
for (unsigned i = 0; i < b.size(); i++) {
|
||||
if (v1.find(b[i]) == v1.end())
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void driver(unsigned max_sz, unsigned max_val, unsigned num_ops, double insert_freq, double copy_freq) {
|
||||
int_set v1;
|
||||
int_rb_tree v2;
|
||||
int_rb_tree v3;
|
||||
std::mt19937 rng;
|
||||
size_t acc_sz = 0;
|
||||
size_t acc_depth = 0;
|
||||
rng.seed(static_cast<unsigned int>(time(0)));
|
||||
std::uniform_int_distribution<unsigned int> uint_dist;
|
||||
|
||||
std::vector<int_rb_tree> copies;
|
||||
for (unsigned i = 0; i < num_ops; i++) {
|
||||
acc_sz += v1.size();
|
||||
acc_depth += v2.get_depth();
|
||||
double f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
|
||||
if (f < copy_freq) {
|
||||
copies.push_back(v2);
|
||||
}
|
||||
f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
|
||||
// read random positions of v3
|
||||
for (unsigned int j = 0; j < uint_dist(rng) % 5; j++) {
|
||||
int a = uint_dist(rng) % max_val;
|
||||
lean_assert(v3.contains(a) == (v1.find(a) != v1.end()));
|
||||
}
|
||||
if (f < insert_freq) {
|
||||
if (v1.size() >= max_sz)
|
||||
continue;
|
||||
int a = uint_dist(rng) % max_val;
|
||||
v1.insert(a);
|
||||
v2.insert(a);
|
||||
v3 = insert(v3, a);
|
||||
} else {
|
||||
int a = uint_dist(rng) % max_val;
|
||||
v1.erase(a);
|
||||
v2.erase(a);
|
||||
v3 = erase(v3, a);
|
||||
}
|
||||
lean_assert(v1 == v2);
|
||||
lean_assert(v1 == v3);
|
||||
lean_assert(v1.size() == v2.size());
|
||||
}
|
||||
std::cout << "\n";
|
||||
std::cout << "Copies created: " << copies.size() << "\n";
|
||||
std::cout << "Average size: " << static_cast<double>(acc_sz) / static_cast<double>(num_ops) << "\n";
|
||||
std::cout << "Average depth: " << static_cast<double>(acc_depth) / static_cast<double>(num_ops) << "\n";
|
||||
}
|
||||
|
||||
static void tst4() {
|
||||
driver(4, 32, 10000, 0.5, 0.01);
|
||||
driver(4, 10000, 10000, 0.5, 0.01);
|
||||
driver(16, 16, 10000, 0.5, 0.1);
|
||||
driver(128, 64, 10000, 0.5, 0.1);
|
||||
driver(128, 64, 10000, 0.4, 0.1);
|
||||
driver(128, 1000, 10000, 0.5, 0.5);
|
||||
driver(128, 1000, 10000, 0.5, 0.01);
|
||||
driver(1024, 10000, 10000, 0.8, 0.01);
|
||||
}
|
||||
|
||||
|
||||
static void tst5() {
|
||||
int_rb_tree s;
|
||||
s.insert(10);
|
||||
s.insert(20);
|
||||
lean_assert(s.find(30) == nullptr);
|
||||
lean_assert(*(s.find(20)) == 20);
|
||||
lean_assert(*(s.find(10)) == 10);
|
||||
}
|
||||
|
||||
int main() {
|
||||
tst1();
|
||||
tst2();
|
||||
tst3();
|
||||
tst4();
|
||||
tst5();
|
||||
return has_violations() ? 1 : 0;
|
||||
}
|
||||
|
309
src/util/rb_tree.h
Normal file
309
src/util/rb_tree.h
Normal file
|
@ -0,0 +1,309 @@
|
|||
/*
|
||||
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include <utility>
|
||||
#include "util/rc.h"
|
||||
#include "util/debug.h"
|
||||
#include "util/buffer.h"
|
||||
|
||||
namespace lean {
|
||||
/**
|
||||
\brief Left-leaning Red-Black Trees
|
||||
|
||||
It uses a O(1) copy operation. Different trees can share nodes.
|
||||
The sharing is thread-safe.
|
||||
|
||||
\c CMP is a functional object for comparing values of type T.
|
||||
It must have a method
|
||||
<code>
|
||||
int operator()(T const & v1, T const & v2) const;
|
||||
</code>
|
||||
The method must return
|
||||
- -1 if <tt>v1 < v2</tt>,
|
||||
- 0 if <tt>v1 == v2</tt>,
|
||||
- 1 if <tt>v1 > v2</tt>
|
||||
*/
|
||||
template<typename T, typename CMP>
|
||||
class rb_tree : public CMP {
|
||||
struct node_cell;
|
||||
struct node {
|
||||
node_cell * m_ptr;
|
||||
node():m_ptr(nullptr) {}
|
||||
node(node_cell * ptr):m_ptr(ptr) { if (m_ptr) ptr->inc_ref(); }
|
||||
node(node const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); }
|
||||
node(node && s):m_ptr(s.m_ptr) { s.m_ptr = nullptr; }
|
||||
~node() { if (m_ptr) m_ptr->dec_ref(); }
|
||||
node & operator=(node const & n) { LEAN_COPY_REF(n); }
|
||||
node & operator=(node&& n) { LEAN_MOVE_REF(n); }
|
||||
operator bool() const { return m_ptr != nullptr; }
|
||||
bool is_shared() const { return m_ptr && m_ptr->get_rc() > 1; }
|
||||
bool is_red() const { return m_ptr && m_ptr->m_red; }
|
||||
bool is_black() const { return !is_red(); }
|
||||
node_cell * operator->() const { lean_assert(m_ptr); return m_ptr; }
|
||||
friend bool is_eqp(node const & n1, node const & n2) { return n1.m_ptr == n2.m_ptr; }
|
||||
friend void swap(node & n1, node & n2) { std::swap(n1.m_ptr, n2.m_ptr); }
|
||||
node steal() { node r; swap(r, *this); return r; }
|
||||
};
|
||||
|
||||
struct node_cell {
|
||||
node m_left;
|
||||
node m_right;
|
||||
T m_value;
|
||||
bool m_red;
|
||||
MK_LEAN_RC();
|
||||
void dealloc() { delete this; }
|
||||
node_cell(T const & v):m_value(v), m_red(true), m_rc(0) {}
|
||||
node_cell(node_cell const & s):m_left(s.m_left), m_right(s.m_right), m_value(s.m_value), m_red(s.m_red), m_rc(0) {}
|
||||
};
|
||||
|
||||
int cmp(T const & v1, T const & v2) const {
|
||||
return CMP::operator()(v1, v2);
|
||||
}
|
||||
|
||||
static node ensure_unshared(node && n) {
|
||||
if (n.is_shared()) {
|
||||
// std::cout << "SHARED\n";
|
||||
return node(new node_cell(*n.m_ptr));
|
||||
} else
|
||||
return n;
|
||||
}
|
||||
|
||||
static node set_black(node && n) {
|
||||
if (n.is_black())
|
||||
return n;
|
||||
node r = ensure_unshared(n.steal());
|
||||
r->m_red = false;
|
||||
return r;
|
||||
}
|
||||
|
||||
static node rotate_left(node && h) {
|
||||
lean_assert(!h.is_shared());
|
||||
node x = ensure_unshared(h->m_right.steal());
|
||||
lean_assert(!h->m_right); // x stole the ownership of h->m_right
|
||||
h->m_right = x->m_left;
|
||||
x->m_left = h;
|
||||
x->m_red = h->m_red;
|
||||
h->m_red = true;
|
||||
return x;
|
||||
}
|
||||
|
||||
static node rotate_right(node && h) {
|
||||
lean_assert(!h.is_shared());
|
||||
node x = ensure_unshared(h->m_left.steal());
|
||||
lean_assert(!h->m_left); // x stole the ownership of h->m_left
|
||||
h->m_left = x->m_right;
|
||||
x->m_right = h;
|
||||
x->m_red = h->m_red;
|
||||
h->m_red = true;
|
||||
return x;
|
||||
}
|
||||
|
||||
static node flip_colors(node && h) {
|
||||
lean_assert(!h.is_shared());
|
||||
h->m_red = !h->m_red;
|
||||
h->m_left = ensure_unshared(h->m_left.steal());
|
||||
h->m_right = ensure_unshared(h->m_right.steal());
|
||||
h->m_left->m_red = !h->m_left->m_red;
|
||||
h->m_right->m_red = !h->m_right->m_red;
|
||||
return h;
|
||||
}
|
||||
|
||||
static node fixup(node && h) {
|
||||
lean_assert(!h.is_shared());
|
||||
if (h->m_right.is_red() && !h->m_left.is_red())
|
||||
h = rotate_left(h.steal());
|
||||
if (h->m_left.is_red() && h->m_left->m_left.is_red())
|
||||
h = rotate_right(h.steal());
|
||||
if (h->m_left.is_red() && h->m_right.is_red())
|
||||
h = flip_colors(h.steal());
|
||||
return h;
|
||||
}
|
||||
|
||||
node insert(node && n, T const & v) {
|
||||
if (!n)
|
||||
return node(new node_cell(v));
|
||||
node h = ensure_unshared(n.steal());
|
||||
|
||||
int c = cmp(v, h->m_value);
|
||||
if (c == 0)
|
||||
h->m_value = v;
|
||||
else if (c < 0)
|
||||
h->m_left = insert(h->m_left.steal(), v);
|
||||
else
|
||||
h->m_right = insert(h->m_right.steal(), v);
|
||||
return fixup(h.steal());
|
||||
}
|
||||
|
||||
static node move_red_left(node && h) {
|
||||
lean_assert(!h.is_shared());
|
||||
h = flip_colors(h.steal());
|
||||
if (h->m_right && h->m_right->m_left.is_red()) {
|
||||
h->m_right = rotate_right(h->m_right.steal());
|
||||
h = rotate_left(h.steal());
|
||||
return flip_colors(h.steal());
|
||||
} else {
|
||||
return h;
|
||||
}
|
||||
}
|
||||
|
||||
static node move_red_right(node && h) {
|
||||
lean_assert(!h.is_shared());
|
||||
h = flip_colors(h.steal());
|
||||
if (h->m_left && h->m_left->m_left.is_red()) {
|
||||
h = rotate_right(h.steal());
|
||||
return flip_colors(h.steal());
|
||||
} else {
|
||||
return h;
|
||||
}
|
||||
}
|
||||
|
||||
static node erase_min(node && n) {
|
||||
if (!n->m_left)
|
||||
return node();
|
||||
node h = ensure_unshared(n.steal());
|
||||
if (!h->m_left.is_red() && !h->m_left->m_left.is_red())
|
||||
h = move_red_left(h.steal());
|
||||
h->m_left = erase_min(h->m_left.steal());
|
||||
return fixup(h.steal());
|
||||
}
|
||||
|
||||
static T const * min(node const & n) {
|
||||
node_cell const * it = n.m_ptr;
|
||||
if (!it)
|
||||
return nullptr;
|
||||
while (it->m_left)
|
||||
it = it->m_left.m_ptr;
|
||||
return &it->m_value;
|
||||
}
|
||||
|
||||
node erase(node && n, T const & v) {
|
||||
lean_assert(n);
|
||||
node h = ensure_unshared(n.steal());
|
||||
if (cmp(v, h->m_value) < 0) {
|
||||
lean_assert(h->m_left); // the tree contains v
|
||||
if (!h->m_left.is_red() && !h->m_left->m_left.is_red())
|
||||
h = move_red_left(h.steal());
|
||||
h->m_left = erase(h->m_left.steal(), v);
|
||||
} else {
|
||||
if (h->m_left.is_red())
|
||||
h = rotate_right(h.steal());
|
||||
if (cmp(v, h->m_value) == 0 && !h->m_right)
|
||||
return node();
|
||||
lean_assert(h->m_right);
|
||||
if (!h->m_right.is_red() && !h->m_right->m_left.is_red())
|
||||
h = move_red_right(h.steal());
|
||||
if (cmp(v, h->m_value) == 0) {
|
||||
h->m_value = *min(h->m_right);
|
||||
h->m_right = erase_min(h->m_right.steal());
|
||||
} else {
|
||||
h->m_right = erase(h->m_right.steal(), v);
|
||||
}
|
||||
}
|
||||
return fixup(h.steal());
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
static void for_each(F && f, node_cell const * n) {
|
||||
if (n) {
|
||||
for_each(f, n->m_left.m_ptr);
|
||||
f(n->m_value);
|
||||
for_each(f, n->m_right.m_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
static void display(std::ostream & out, node_cell const * n) {
|
||||
if (n) {
|
||||
out << "(";
|
||||
if (n->m_red)
|
||||
out << "*";
|
||||
out << n->m_value << " ";
|
||||
display(out, n->m_left.m_ptr);
|
||||
out << " ";
|
||||
display(out, n->m_right.m_ptr);
|
||||
out << ")";
|
||||
} else {
|
||||
out << "nil";
|
||||
}
|
||||
}
|
||||
|
||||
static unsigned get_depth(node_cell const * n) {
|
||||
if (n)
|
||||
return std::max(get_depth(n->m_left.m_ptr), get_depth(n->m_right.m_ptr)) + 1;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void to_buffer(node_cell const * n, buffer<T> & r) {
|
||||
if (n) {
|
||||
to_buffer(n->m_left.m_ptr, r);
|
||||
r.push_back(n->m_value);
|
||||
to_buffer(n->m_right.m_ptr, r);
|
||||
}
|
||||
}
|
||||
|
||||
node m_root;
|
||||
|
||||
public:
|
||||
void insert(T const & v) { m_root = set_black(insert(m_root.steal(), v)); }
|
||||
void erase_min(T const & v) { m_root = set_black(erase_min(m_root.steal())); }
|
||||
void erase_core(T const & v) { lean_assert(contains(v)); m_root = set_black(erase(m_root.steal(), v)); }
|
||||
void erase(T const & v) { if (contains(v)) erase_core(v); }
|
||||
|
||||
T const * find(T const & v) const {
|
||||
node_cell const * h = m_root.m_ptr;
|
||||
while (h) {
|
||||
int c = cmp(v, h->m_value);
|
||||
if (c == 0)
|
||||
return &(h->m_value);
|
||||
else if (c < 0)
|
||||
h = h->m_left.m_ptr;
|
||||
else
|
||||
h = h->m_right.m_ptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
T const * min() const { return min(m_root); }
|
||||
bool contains(T const & v) const { return find(v) != nullptr; }
|
||||
|
||||
template<typename F>
|
||||
void for_each(F && f) const { for_each(f, m_root.m_ptr); }
|
||||
|
||||
// For debugging purposes
|
||||
void display(std::ostream & out) const { display(out, m_root.m_ptr); }
|
||||
|
||||
unsigned get_depth() const { return get_depth(m_root.m_ptr); }
|
||||
|
||||
unsigned size() const {
|
||||
unsigned r = 0;
|
||||
for_each([&](T const & ){ r = r + 1; });
|
||||
return r;
|
||||
}
|
||||
|
||||
bool empty() const { return m_root.m_ptr == nullptr; }
|
||||
|
||||
void clear() { m_root = node(); }
|
||||
|
||||
friend std::ostream & operator<<(std::ostream & out, rb_tree const & t) {
|
||||
t.display(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Copy the contents of this tree to the given buffer.
|
||||
The elements will be stored in increasing order.
|
||||
*/
|
||||
void to_buffer(buffer<T> & r) const {
|
||||
to_buffer(m_root.m_ptr, r);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename CMP>
|
||||
rb_tree<T, CMP> insert(rb_tree<T, CMP> & t, T const & v) { rb_tree<T, CMP> r(t); r.insert(v); return r; }
|
||||
template<typename T, typename CMP>
|
||||
rb_tree<T, CMP> erase(rb_tree<T, CMP> & t, T const & v) { rb_tree<T, CMP> r(t); r.erase(v); return r; }
|
||||
}
|
Loading…
Reference in a new issue