Add remaining splay tree methods
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
01f5fa59b1
commit
b78b2e0585
2 changed files with 174 additions and 25 deletions
|
@ -60,14 +60,17 @@ void tst1() {
|
||||||
lean_assert(s.contains(3));
|
lean_assert(s.contains(3));
|
||||||
lean_assert(s.contains(20));
|
lean_assert(s.contains(20));
|
||||||
std::cout << s << "\n";
|
std::cout << s << "\n";
|
||||||
std::cout << "BEFORE CONSTR\n";
|
|
||||||
int_splay_tree s2(s);
|
int_splay_tree s2(s);
|
||||||
std::cout << "AFTER CONSTR\n";
|
|
||||||
std::cout << s2 << "\n";
|
std::cout << s2 << "\n";
|
||||||
s.insert(34);
|
s.insert(34);
|
||||||
std::cout << s2 << "\n";
|
std::cout << s2 << "\n";
|
||||||
std::cout << s << "\n";
|
std::cout << s << "\n";
|
||||||
std::cout << "END\n";
|
int const * v = s.find_memoize(11);
|
||||||
|
lean_assert(*v == 11);
|
||||||
|
std::cout << s << "\n";
|
||||||
|
lean_assert(!s.empty());
|
||||||
|
s.clear();
|
||||||
|
lean_assert(s.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(int_set const & v1, int_splay_tree const & v2) {
|
bool operator==(int_set const & v1, int_splay_tree const & v2) {
|
||||||
|
@ -89,12 +92,13 @@ static void driver(unsigned max_sz, unsigned max_val, unsigned num_ops, double i
|
||||||
int_splay_tree v2;
|
int_splay_tree v2;
|
||||||
int_splay_tree v3;
|
int_splay_tree v3;
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
size_t acc_sz = 0;
|
||||||
rng.seed(static_cast<unsigned int>(time(0)));
|
rng.seed(static_cast<unsigned int>(time(0)));
|
||||||
std::uniform_int_distribution<unsigned int> uint_dist;
|
std::uniform_int_distribution<unsigned int> uint_dist;
|
||||||
|
|
||||||
std::vector<int_splay_tree> copies;
|
std::vector<int_splay_tree> copies;
|
||||||
for (unsigned i = 0; i < num_ops; i++) {
|
for (unsigned i = 0; i < num_ops; i++) {
|
||||||
|
acc_sz += v1.size();
|
||||||
double f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
|
double f = static_cast<double>(uint_dist(rng) % 10000) / 10000.0;
|
||||||
if (f < copy_freq) {
|
if (f < copy_freq) {
|
||||||
copies.push_back(v2);
|
copies.push_back(v2);
|
||||||
|
@ -113,17 +117,26 @@ static void driver(unsigned max_sz, unsigned max_val, unsigned num_ops, double i
|
||||||
v2.insert(a);
|
v2.insert(a);
|
||||||
v3 = insert(v3, a);
|
v3 = insert(v3, a);
|
||||||
} else {
|
} else {
|
||||||
// TODO(Leo): erase operation for splay_trees
|
int a = uint_dist(rng) % max_val;
|
||||||
|
v1.erase(a);
|
||||||
|
v2.erase(a);
|
||||||
|
v3 = erase(v3, a);
|
||||||
}
|
}
|
||||||
lean_assert(v1 == v2);
|
lean_assert(v1 == v2);
|
||||||
lean_assert(v1 == v3);
|
lean_assert(v1 == v3);
|
||||||
|
lean_assert(v1.size() == v2.size());
|
||||||
}
|
}
|
||||||
|
std::cout << "\n";
|
||||||
std::cout << "Copies created: " << copies.size() << "\n";
|
std::cout << "Copies created: " << copies.size() << "\n";
|
||||||
|
std::cout << "Average size: " << static_cast<double>(acc_sz) / static_cast<double>(num_ops) << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
static void tst2() {
|
static void tst2() {
|
||||||
driver(4, 32, 10000, 0.5, 0.01);
|
driver(4, 32, 10000, 0.5, 0.01);
|
||||||
driver(4, 10000, 10000, 0.5, 0.01);
|
driver(4, 10000, 10000, 0.5, 0.01);
|
||||||
|
driver(16, 16, 10000, 0.5, 0.1);
|
||||||
|
driver(128, 64, 10000, 0.5, 0.1);
|
||||||
|
driver(128, 64, 10000, 0.4, 0.1);
|
||||||
driver(128, 1000, 10000, 0.5, 0.5);
|
driver(128, 1000, 10000, 0.5, 0.5);
|
||||||
driver(128, 1000, 10000, 0.5, 0.01);
|
driver(128, 1000, 10000, 0.5, 0.01);
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ Author: Leonardo de Moura
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <algorithm>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "util/rc.h"
|
#include "util/rc.h"
|
||||||
|
@ -14,7 +15,22 @@ Author: Leonardo de Moura
|
||||||
#include "util/buffer.h"
|
#include "util/buffer.h"
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
|
/**
|
||||||
|
\brief Splay trees (see http://en.wikipedia.org/wiki/Splay_tree)
|
||||||
|
|
||||||
|
It uses a O(1) copy operation. Different tree can share nodes.
|
||||||
|
The sharing is thread-safe.
|
||||||
|
|
||||||
|
\c CMP is a functional object for comparing values of type T.
|
||||||
|
It must have a method
|
||||||
|
<code>
|
||||||
|
int operator()(T const & v1, T const & v2) const;
|
||||||
|
</code>
|
||||||
|
The method must return
|
||||||
|
- -1 if <tt>v1 < v2</tt>,
|
||||||
|
- 0 if <tt>v1 == v2</tt>,
|
||||||
|
- 1 if <tt>v1 > v2</tt>
|
||||||
|
*/
|
||||||
template<typename T, typename CMP>
|
template<typename T, typename CMP>
|
||||||
class splay_tree : public CMP {
|
class splay_tree : public CMP {
|
||||||
struct node {
|
struct node {
|
||||||
|
@ -183,6 +199,7 @@ class splay_tree : public CMP {
|
||||||
child->inc_ref();
|
child->inc_ref();
|
||||||
entry const & last = path.back();
|
entry const & last = path.back();
|
||||||
node * parent = last.m_node;
|
node * parent = last.m_node;
|
||||||
|
lean_assert(!parent->is_shared());
|
||||||
if (last.m_right) {
|
if (last.m_right) {
|
||||||
node::dec_ref(parent->m_right);
|
node::dec_ref(parent->m_right);
|
||||||
parent->m_right = child;
|
parent->m_right = child;
|
||||||
|
@ -201,24 +218,21 @@ class splay_tree : public CMP {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
bool insert_pull(T const & v, bool is_insert) {
|
||||||
splay_tree(CMP const & cmp = CMP()):CMP(cmp), m_ptr(nullptr) {}
|
|
||||||
splay_tree(splay_tree const & s):CMP(s), m_ptr(s.m_ptr) { node::inc_ref(m_ptr); }
|
|
||||||
splay_tree(splay_tree && s):CMP(s), m_ptr(s.m_ptr) { s.m_ptr = nullptr; }
|
|
||||||
~splay_tree() { node::dec_ref(m_ptr); }
|
|
||||||
|
|
||||||
splay_tree & operator=(splay_tree const & s) { LEAN_COPY_REF(splay_tree, s); }
|
|
||||||
splay_tree & operator=(splay_tree && s) { LEAN_MOVE_REF(splay_tree, s); }
|
|
||||||
|
|
||||||
bool empty() const { return m_ptr == nullptr; }
|
|
||||||
|
|
||||||
void insert(T const & v) {
|
|
||||||
static thread_local std::vector<entry> path;
|
static thread_local std::vector<entry> path;
|
||||||
node * n = m_ptr;
|
node * n = m_ptr;
|
||||||
|
bool found = false;
|
||||||
while (true) {
|
while (true) {
|
||||||
if (n == nullptr) {
|
if (n == nullptr) {
|
||||||
|
if (is_insert) {
|
||||||
n = new node(v);
|
n = new node(v);
|
||||||
update_parent(path, n);
|
update_parent(path, n);
|
||||||
|
} else {
|
||||||
|
if (path.empty())
|
||||||
|
return false;
|
||||||
|
n = path.back().m_node;
|
||||||
|
path.pop_back();
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
if (n->is_shared()) {
|
if (n->is_shared()) {
|
||||||
|
@ -234,39 +248,159 @@ public:
|
||||||
path.push_back(entry(true, n));
|
path.push_back(entry(true, n));
|
||||||
n = n->m_right;
|
n = n->m_right;
|
||||||
} else {
|
} else {
|
||||||
|
if (is_insert)
|
||||||
n->m_value = v;
|
n->m_value = v;
|
||||||
|
found = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
splay_to_top(path, n);
|
splay_to_top(path, n);
|
||||||
m_ptr = n;
|
m_ptr = n;
|
||||||
lean_assert(check_invariant())
|
lean_assert(check_invariant());
|
||||||
|
return found;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool contains(T const & v) const {
|
bool pull(T const & v) {
|
||||||
|
return insert_pull(v, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void pull_max() {
|
||||||
|
if (!m_ptr) return;
|
||||||
|
static thread_local std::vector<entry> path;
|
||||||
|
node * n = m_ptr;
|
||||||
|
while (true) {
|
||||||
|
lean_assert(n);
|
||||||
|
if (n->is_shared()) {
|
||||||
|
n = new node(*n);
|
||||||
|
update_parent(path, n);
|
||||||
|
}
|
||||||
|
if (n->m_right) {
|
||||||
|
path.push_back(entry(true, n));
|
||||||
|
n = n->m_right;
|
||||||
|
} else {
|
||||||
|
splay_to_top(path, n);
|
||||||
|
m_ptr = n;
|
||||||
|
lean_assert(check_invariant());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static unsigned size(node const * n) {
|
||||||
|
if (n)
|
||||||
|
return 1 + size(n->m_left) + size(n->m_right);
|
||||||
|
else
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
splay_tree(splay_tree const & s, node * new_root):CMP(s), m_ptr(new_root) { node::inc_ref(m_ptr); }
|
||||||
|
|
||||||
|
public:
|
||||||
|
splay_tree(CMP const & cmp = CMP()):CMP(cmp), m_ptr(nullptr) {}
|
||||||
|
splay_tree(splay_tree const & s):CMP(s), m_ptr(s.m_ptr) { node::inc_ref(m_ptr); }
|
||||||
|
splay_tree(splay_tree && s):CMP(s), m_ptr(s.m_ptr) { s.m_ptr = nullptr; }
|
||||||
|
~splay_tree() { node::dec_ref(m_ptr); }
|
||||||
|
|
||||||
|
/** \brief O(1) copy */
|
||||||
|
splay_tree & operator=(splay_tree const & s) { LEAN_COPY_REF(splay_tree, s); }
|
||||||
|
/** \brief O(1) move */
|
||||||
|
splay_tree & operator=(splay_tree && s) { LEAN_MOVE_REF(splay_tree, s); }
|
||||||
|
|
||||||
|
void swap(splay_tree & t) { std::swap(m_ptr, t.m_ptr); }
|
||||||
|
|
||||||
|
/** \brief Return true iff this splay tree is empty. */
|
||||||
|
bool empty() const { return m_ptr == nullptr; }
|
||||||
|
|
||||||
|
/** \brief Remove all elements from the splay tree. */
|
||||||
|
void clear() { node::dec_ref(m_ptr); m_ptr = nullptr; }
|
||||||
|
|
||||||
|
/** \brief Return true iff this splay tree and \c t point to the same node */
|
||||||
|
bool is_eqp(splay_tree const & t) const { return m_ptr == t.m_ptr; }
|
||||||
|
|
||||||
|
/** \brief Return the size of the splay tree */
|
||||||
|
unsigned size() const { return size(m_ptr); }
|
||||||
|
|
||||||
|
/** \brief Insert \c v in this splay tree. */
|
||||||
|
void insert(T const & v) {
|
||||||
|
insert_pull(v, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
\brief Return a pointer to a value equal to \c v that is stored in this splay tree.
|
||||||
|
If the splay tree does not contain any value equal to \c v, then return \c nullptr.
|
||||||
|
|
||||||
|
\remark <tt>find(v) != nullptr</tt> iff <tt>contains(v)</tt>
|
||||||
|
*/
|
||||||
|
T const * find(T const & v) const {
|
||||||
node const * n = m_ptr;
|
node const * n = m_ptr;
|
||||||
while (true) {
|
while (true) {
|
||||||
if (n == nullptr)
|
if (n == nullptr)
|
||||||
return false;
|
return nullptr;
|
||||||
int c = cmp(v, n->m_value);
|
int c = cmp(v, n->m_value);
|
||||||
if (c < 0)
|
if (c < 0)
|
||||||
n = n->m_left;
|
n = n->m_left;
|
||||||
else if (c > 0)
|
else if (c > 0)
|
||||||
n = n->m_right;
|
n = n->m_right;
|
||||||
else
|
else
|
||||||
return true;
|
return &(n->m_value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** \brief Return true iff the splay tree contains an element equal to \c v. */
|
||||||
|
bool contains(T const & v) const {
|
||||||
|
return find(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
\brief Similar to \c find, but the splay tree is reorganized.
|
||||||
|
If <tt>find(v)</tt> is invoked after <tt>find_memoize(v)</tt>, then the cost will be O(1).
|
||||||
|
The idea is to move recently accessed elements close to the root.
|
||||||
|
*/
|
||||||
|
T const * find_memoize(T const & v) {
|
||||||
|
if (pull(v)) {
|
||||||
|
lean_assert(cmp(m_ptr->m_value, v) == 0);
|
||||||
|
return &(m_ptr->m_value);
|
||||||
|
} else {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** \brief Remove \c v from this splay tree. Actually, it removes an element that is equal to \c v. */
|
||||||
|
void erase(T const & v) {
|
||||||
|
if (pull(v)) {
|
||||||
|
lean_assert(cmp(m_ptr->m_value, v) == 0);
|
||||||
|
splay_tree left(*this, m_ptr->m_left);
|
||||||
|
splay_tree right(*this, m_ptr->m_right);
|
||||||
|
if (left.empty()) {
|
||||||
|
swap(right);
|
||||||
|
} else if (right.empty()) {
|
||||||
|
swap(left);
|
||||||
|
} else {
|
||||||
|
clear();
|
||||||
|
left.pull_max();
|
||||||
|
lean_assert(left.m_ptr->m_right == nullptr);
|
||||||
|
right.m_ptr->inc_ref();
|
||||||
|
left.m_ptr->m_right = right.m_ptr;
|
||||||
|
swap(left);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** \brief (For debugging) Check whether this splay tree is well formed. */
|
||||||
bool check_invariant() const {
|
bool check_invariant() const {
|
||||||
return check_invariant(m_ptr);
|
return check_invariant(m_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
\brief Copy the contents of this splay tree to the given buffer.
|
||||||
|
The elements will be stored in increasing order.
|
||||||
|
*/
|
||||||
void to_buffer(buffer<T> & r) const {
|
void to_buffer(buffer<T> & r) const {
|
||||||
to_buffer(m_ptr, r);
|
to_buffer(m_ptr, r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** \brief (For debugging) Display the content of this splay tree. */
|
||||||
friend std::ostream & operator<<(std::ostream & out, splay_tree const & t) {
|
friend std::ostream & operator<<(std::ostream & out, splay_tree const & t) {
|
||||||
node::display(out, t.m_ptr);
|
node::display(out, t.m_ptr);
|
||||||
return out;
|
return out;
|
||||||
|
@ -274,4 +408,6 @@ public:
|
||||||
};
|
};
|
||||||
template<typename T, typename CMP>
|
template<typename T, typename CMP>
|
||||||
splay_tree<T, CMP> insert(splay_tree<T, CMP> & t, T const & v) { splay_tree<T, CMP> r(t); r.insert(v); return r; }
|
splay_tree<T, CMP> insert(splay_tree<T, CMP> & t, T const & v) { splay_tree<T, CMP> r(t); r.insert(v); return r; }
|
||||||
|
template<typename T, typename CMP>
|
||||||
|
splay_tree<T, CMP> erase(splay_tree<T, CMP> & t, T const & v) { splay_tree<T, CMP> r(t); r.erase(v); return r; }
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue