From 528ea367adddfb91ca5b62bb9e50f9d06d11b7ce Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 21 Feb 2014 17:20:14 -0800 Subject: [PATCH] feat(util): add red-black trees Signed-off-by: Leonardo de Moura --- src/tests/util/CMakeLists.txt | 3 + src/tests/util/rb_tree.cpp | 190 +++++++++++++++++++++ src/util/rb_tree.h | 309 ++++++++++++++++++++++++++++++++++ 3 files changed, 502 insertions(+) create mode 100644 src/tests/util/rb_tree.cpp create mode 100644 src/util/rb_tree.h diff --git a/src/tests/util/CMakeLists.txt b/src/tests/util/CMakeLists.txt index 705983a74..7ec77af46 100644 --- a/src/tests/util/CMakeLists.txt +++ b/src/tests/util/CMakeLists.txt @@ -28,6 +28,9 @@ add_test(thread ${CMAKE_CURRENT_BINARY_DIR}/thread) add_executable(memory memory.cpp) target_link_libraries(memory ${EXTRA_LIBS}) add_test(memory ${CMAKE_CURRENT_BINARY_DIR}/memory) +add_executable(rb_tree rb_tree.cpp) +target_link_libraries(rb_tree ${EXTRA_LIBS}) +add_test(rb_tree ${CMAKE_CURRENT_BINARY_DIR}/rb_tree) add_executable(splay_tree splay_tree.cpp) target_link_libraries(splay_tree ${EXTRA_LIBS}) add_test(splay_tree ${CMAKE_CURRENT_BINARY_DIR}/splay_tree) diff --git a/src/tests/util/rb_tree.cpp b/src/tests/util/rb_tree.cpp new file mode 100644 index 000000000..5e532db38 --- /dev/null +++ b/src/tests/util/rb_tree.cpp @@ -0,0 +1,190 @@ +/* +Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include +#include +#include +#include +#include +#include +#include "util/test.h" +#include "util/buffer.h" +#include "util/rb_tree.h" +#include "util/timeit.h" +using namespace lean; + +struct int_lt { int operator()(int i1, int i2) const { return i1 < i2 ? -1 : (i1 > i2 ? 1 : 0); } }; +typedef rb_tree int_rb_tree; +typedef std::unordered_set int_set; + +void print(int_rb_tree const & t) { + std::cout << t << "\n"; +} + +static void tst1() { + int_rb_tree s; + for (unsigned i = 0; i < 100; i++) { + s.insert(i); + } + std::cout << s << "\n"; + std::cout << "DEPTH: " << s.get_depth() << "\n"; + int_rb_tree s2 = s; + std::cout << "DEPTH: " << s2.get_depth() << "\n"; + s2.insert(200); + lean_assert_eq(s2.size(), s.size() + 1); + for (unsigned i = 0; i < 100; i++) { + lean_assert(s.contains(i)); + lean_assert(s2.contains(i)); + } + lean_assert(!s.contains(200)); + lean_assert(s2.contains(200)); +} + +static void tst2() { + int_rb_tree s; + s.insert(10); + s.insert(11); + s.insert(9); + std::cout << s << "\n"; + int_rb_tree s2 = s; + std::cout << s2 << "\n"; + s.insert(20); + std::cout << s << "\n"; + s.insert(15); +} + +static void tst3() { + int_rb_tree s; + s.insert(10); + s.insert(3); + s.insert(20); + std::cout << s << "\n"; + s.insert(40); + std::cout << s << "\n"; + s.insert(5); + std::cout << s << "\n"; + s.insert(11); + std::cout << s << "\n"; + s.insert(20); + std::cout << s << "\n"; + s.insert(30); + std::cout << s << "\n"; + s.insert(25); + std::cout << s << "\n"; + s.insert(15); + lean_assert(s.contains(40)); + lean_assert(s.contains(11)); + lean_assert(s.contains(20)); + lean_assert(s.contains(25)); + lean_assert(s.contains(5)); + lean_assert(s.contains(10)); + lean_assert(s.contains(3)); + lean_assert(s.contains(20)); + std::cout << s << "\n"; + int_rb_tree s2(s); + std::cout << s2 << "\n"; + s.insert(34); + std::cout << s2 << "\n"; + std::cout << s << "\n"; + int const * v = s.find(11); + lean_assert(*v == 11); + std::cout << s << "\n"; + lean_assert(!s.empty()); + s.clear(); + lean_assert(s.empty()); +} + +static bool operator==(int_set const & v1, int_rb_tree const & v2) { + buffer b; + // std::cout << v2 << "\n"; + // std::for_each(v1.begin(), v1.end(), [](int v) { std::cout << v << " "; }); std::cout << "\n"; + v2.to_buffer(b); + if (v1.size() != b.size()) + return false; + for (unsigned i = 0; i < b.size(); i++) { + if (v1.find(b[i]) == v1.end()) + return false; + } + return true; +} + +static void driver(unsigned max_sz, unsigned max_val, unsigned num_ops, double insert_freq, double copy_freq) { + int_set v1; + int_rb_tree v2; + int_rb_tree v3; + std::mt19937 rng; + size_t acc_sz = 0; + size_t acc_depth = 0; + rng.seed(static_cast(time(0))); + std::uniform_int_distribution uint_dist; + + std::vector copies; + for (unsigned i = 0; i < num_ops; i++) { + acc_sz += v1.size(); + acc_depth += v2.get_depth(); + double f = static_cast(uint_dist(rng) % 10000) / 10000.0; + if (f < copy_freq) { + copies.push_back(v2); + } + f = static_cast(uint_dist(rng) % 10000) / 10000.0; + // read random positions of v3 + for (unsigned int j = 0; j < uint_dist(rng) % 5; j++) { + int a = uint_dist(rng) % max_val; + lean_assert(v3.contains(a) == (v1.find(a) != v1.end())); + } + if (f < insert_freq) { + if (v1.size() >= max_sz) + continue; + int a = uint_dist(rng) % max_val; + v1.insert(a); + v2.insert(a); + v3 = insert(v3, a); + } else { + int a = uint_dist(rng) % max_val; + v1.erase(a); + v2.erase(a); + v3 = erase(v3, a); + } + lean_assert(v1 == v2); + lean_assert(v1 == v3); + lean_assert(v1.size() == v2.size()); + } + std::cout << "\n"; + std::cout << "Copies created: " << copies.size() << "\n"; + std::cout << "Average size: " << static_cast(acc_sz) / static_cast(num_ops) << "\n"; + std::cout << "Average depth: " << static_cast(acc_depth) / static_cast(num_ops) << "\n"; +} + +static void tst4() { + driver(4, 32, 10000, 0.5, 0.01); + driver(4, 10000, 10000, 0.5, 0.01); + driver(16, 16, 10000, 0.5, 0.1); + driver(128, 64, 10000, 0.5, 0.1); + driver(128, 64, 10000, 0.4, 0.1); + driver(128, 1000, 10000, 0.5, 0.5); + driver(128, 1000, 10000, 0.5, 0.01); + driver(1024, 10000, 10000, 0.8, 0.01); +} + + +static void tst5() { + int_rb_tree s; + s.insert(10); + s.insert(20); + lean_assert(s.find(30) == nullptr); + lean_assert(*(s.find(20)) == 20); + lean_assert(*(s.find(10)) == 10); +} + +int main() { + tst1(); + tst2(); + tst3(); + tst4(); + tst5(); + return has_violations() ? 1 : 0; +} + diff --git a/src/util/rb_tree.h b/src/util/rb_tree.h new file mode 100644 index 000000000..29db1d32f --- /dev/null +++ b/src/util/rb_tree.h @@ -0,0 +1,309 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include +#include "util/rc.h" +#include "util/debug.h" +#include "util/buffer.h" + +namespace lean { +/** + \brief Left-leaning Red-Black Trees + + It uses a O(1) copy operation. Different trees can share nodes. + The sharing is thread-safe. + + \c CMP is a functional object for comparing values of type T. + It must have a method + + int operator()(T const & v1, T const & v2) const; + + The method must return + - -1 if v1 < v2, + - 0 if v1 == v2, + - 1 if v1 > v2 +*/ +template +class rb_tree : public CMP { + struct node_cell; + struct node { + node_cell * m_ptr; + node():m_ptr(nullptr) {} + node(node_cell * ptr):m_ptr(ptr) { if (m_ptr) ptr->inc_ref(); } + node(node const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } + node(node && s):m_ptr(s.m_ptr) { s.m_ptr = nullptr; } + ~node() { if (m_ptr) m_ptr->dec_ref(); } + node & operator=(node const & n) { LEAN_COPY_REF(n); } + node & operator=(node&& n) { LEAN_MOVE_REF(n); } + operator bool() const { return m_ptr != nullptr; } + bool is_shared() const { return m_ptr && m_ptr->get_rc() > 1; } + bool is_red() const { return m_ptr && m_ptr->m_red; } + bool is_black() const { return !is_red(); } + node_cell * operator->() const { lean_assert(m_ptr); return m_ptr; } + friend bool is_eqp(node const & n1, node const & n2) { return n1.m_ptr == n2.m_ptr; } + friend void swap(node & n1, node & n2) { std::swap(n1.m_ptr, n2.m_ptr); } + node steal() { node r; swap(r, *this); return r; } + }; + + struct node_cell { + node m_left; + node m_right; + T m_value; + bool m_red; + MK_LEAN_RC(); + void dealloc() { delete this; } + node_cell(T const & v):m_value(v), m_red(true), m_rc(0) {} + node_cell(node_cell const & s):m_left(s.m_left), m_right(s.m_right), m_value(s.m_value), m_red(s.m_red), m_rc(0) {} + }; + + int cmp(T const & v1, T const & v2) const { + return CMP::operator()(v1, v2); + } + + static node ensure_unshared(node && n) { + if (n.is_shared()) { + // std::cout << "SHARED\n"; + return node(new node_cell(*n.m_ptr)); + } else + return n; + } + + static node set_black(node && n) { + if (n.is_black()) + return n; + node r = ensure_unshared(n.steal()); + r->m_red = false; + return r; + } + + static node rotate_left(node && h) { + lean_assert(!h.is_shared()); + node x = ensure_unshared(h->m_right.steal()); + lean_assert(!h->m_right); // x stole the ownership of h->m_right + h->m_right = x->m_left; + x->m_left = h; + x->m_red = h->m_red; + h->m_red = true; + return x; + } + + static node rotate_right(node && h) { + lean_assert(!h.is_shared()); + node x = ensure_unshared(h->m_left.steal()); + lean_assert(!h->m_left); // x stole the ownership of h->m_left + h->m_left = x->m_right; + x->m_right = h; + x->m_red = h->m_red; + h->m_red = true; + return x; + } + + static node flip_colors(node && h) { + lean_assert(!h.is_shared()); + h->m_red = !h->m_red; + h->m_left = ensure_unshared(h->m_left.steal()); + h->m_right = ensure_unshared(h->m_right.steal()); + h->m_left->m_red = !h->m_left->m_red; + h->m_right->m_red = !h->m_right->m_red; + return h; + } + + static node fixup(node && h) { + lean_assert(!h.is_shared()); + if (h->m_right.is_red() && !h->m_left.is_red()) + h = rotate_left(h.steal()); + if (h->m_left.is_red() && h->m_left->m_left.is_red()) + h = rotate_right(h.steal()); + if (h->m_left.is_red() && h->m_right.is_red()) + h = flip_colors(h.steal()); + return h; + } + + node insert(node && n, T const & v) { + if (!n) + return node(new node_cell(v)); + node h = ensure_unshared(n.steal()); + + int c = cmp(v, h->m_value); + if (c == 0) + h->m_value = v; + else if (c < 0) + h->m_left = insert(h->m_left.steal(), v); + else + h->m_right = insert(h->m_right.steal(), v); + return fixup(h.steal()); + } + + static node move_red_left(node && h) { + lean_assert(!h.is_shared()); + h = flip_colors(h.steal()); + if (h->m_right && h->m_right->m_left.is_red()) { + h->m_right = rotate_right(h->m_right.steal()); + h = rotate_left(h.steal()); + return flip_colors(h.steal()); + } else { + return h; + } + } + + static node move_red_right(node && h) { + lean_assert(!h.is_shared()); + h = flip_colors(h.steal()); + if (h->m_left && h->m_left->m_left.is_red()) { + h = rotate_right(h.steal()); + return flip_colors(h.steal()); + } else { + return h; + } + } + + static node erase_min(node && n) { + if (!n->m_left) + return node(); + node h = ensure_unshared(n.steal()); + if (!h->m_left.is_red() && !h->m_left->m_left.is_red()) + h = move_red_left(h.steal()); + h->m_left = erase_min(h->m_left.steal()); + return fixup(h.steal()); + } + + static T const * min(node const & n) { + node_cell const * it = n.m_ptr; + if (!it) + return nullptr; + while (it->m_left) + it = it->m_left.m_ptr; + return &it->m_value; + } + + node erase(node && n, T const & v) { + lean_assert(n); + node h = ensure_unshared(n.steal()); + if (cmp(v, h->m_value) < 0) { + lean_assert(h->m_left); // the tree contains v + if (!h->m_left.is_red() && !h->m_left->m_left.is_red()) + h = move_red_left(h.steal()); + h->m_left = erase(h->m_left.steal(), v); + } else { + if (h->m_left.is_red()) + h = rotate_right(h.steal()); + if (cmp(v, h->m_value) == 0 && !h->m_right) + return node(); + lean_assert(h->m_right); + if (!h->m_right.is_red() && !h->m_right->m_left.is_red()) + h = move_red_right(h.steal()); + if (cmp(v, h->m_value) == 0) { + h->m_value = *min(h->m_right); + h->m_right = erase_min(h->m_right.steal()); + } else { + h->m_right = erase(h->m_right.steal(), v); + } + } + return fixup(h.steal()); + } + + template + static void for_each(F && f, node_cell const * n) { + if (n) { + for_each(f, n->m_left.m_ptr); + f(n->m_value); + for_each(f, n->m_right.m_ptr); + } + } + + static void display(std::ostream & out, node_cell const * n) { + if (n) { + out << "("; + if (n->m_red) + out << "*"; + out << n->m_value << " "; + display(out, n->m_left.m_ptr); + out << " "; + display(out, n->m_right.m_ptr); + out << ")"; + } else { + out << "nil"; + } + } + + static unsigned get_depth(node_cell const * n) { + if (n) + return std::max(get_depth(n->m_left.m_ptr), get_depth(n->m_right.m_ptr)) + 1; + else + return 0; + } + + static void to_buffer(node_cell const * n, buffer & r) { + if (n) { + to_buffer(n->m_left.m_ptr, r); + r.push_back(n->m_value); + to_buffer(n->m_right.m_ptr, r); + } + } + + node m_root; + +public: + void insert(T const & v) { m_root = set_black(insert(m_root.steal(), v)); } + void erase_min(T const & v) { m_root = set_black(erase_min(m_root.steal())); } + void erase_core(T const & v) { lean_assert(contains(v)); m_root = set_black(erase(m_root.steal(), v)); } + void erase(T const & v) { if (contains(v)) erase_core(v); } + + T const * find(T const & v) const { + node_cell const * h = m_root.m_ptr; + while (h) { + int c = cmp(v, h->m_value); + if (c == 0) + return &(h->m_value); + else if (c < 0) + h = h->m_left.m_ptr; + else + h = h->m_right.m_ptr; + } + return nullptr; + } + + T const * min() const { return min(m_root); } + bool contains(T const & v) const { return find(v) != nullptr; } + + template + void for_each(F && f) const { for_each(f, m_root.m_ptr); } + + // For debugging purposes + void display(std::ostream & out) const { display(out, m_root.m_ptr); } + + unsigned get_depth() const { return get_depth(m_root.m_ptr); } + + unsigned size() const { + unsigned r = 0; + for_each([&](T const & ){ r = r + 1; }); + return r; + } + + bool empty() const { return m_root.m_ptr == nullptr; } + + void clear() { m_root = node(); } + + friend std::ostream & operator<<(std::ostream & out, rb_tree const & t) { + t.display(out); + return out; + } + + /** + \brief Copy the contents of this tree to the given buffer. + The elements will be stored in increasing order. + */ + void to_buffer(buffer & r) const { + to_buffer(m_root.m_ptr, r); + } +}; + +template +rb_tree insert(rb_tree & t, T const & v) { rb_tree r(t); r.insert(v); return r; } +template +rb_tree erase(rb_tree & t, T const & v) { rb_tree r(t); r.erase(v); return r; } +}