From 54e63fd4dede6e80b91fda12519bbf5fb5e57c6a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 26 Sep 2013 16:55:48 -0700 Subject: [PATCH] feat(splay_tree): add fold and for_each templates for splay_tree and splay_map Signed-off-by: Leonardo de Moura --- src/tests/util/splay_map.cpp | 13 +++++++++ src/tests/util/splay_tree.cpp | 21 ++++++++++++--- src/util/splay_map.h | 32 ++++++++++++++++++++++ src/util/splay_tree.h | 50 ++++++++++++++++++++++++++++++----- 4 files changed, 107 insertions(+), 9 deletions(-) diff --git a/src/tests/util/splay_map.cpp b/src/tests/util/splay_map.cpp index 2b6e08851..401b7e0b2 100644 --- a/src/tests/util/splay_map.cpp +++ b/src/tests/util/splay_map.cpp @@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include +#include #include "util/test.h" #include "util/splay_map.h" #include "util/name.h" @@ -31,7 +32,19 @@ static void tst0() { lean_assert(m2.size() == 3); } +static void tst1() { + int2name m; + m[10] = name("t1"); + m[20] = name("t2"); + lean_assert(fold(m, [](int k, name const &, int a) { return k + a; }, 0) == 30); + std::ostringstream out; + for_each(m, [&](int, name const & v) { out << v << " "; }); + std::cout << out.str() << "\n"; + lean_assert(out.str() == "t1 t2 "); +} + int main() { tst0(); + tst1(); return has_violations() ? 1 : 0; } diff --git a/src/tests/util/splay_tree.cpp b/src/tests/util/splay_tree.cpp index ebd74173d..1293a76ee 100644 --- a/src/tests/util/splay_tree.cpp +++ b/src/tests/util/splay_tree.cpp @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include #include #include +#include #include "util/test.h" #include "util/splay_tree.h" #include "util/timeit.h" @@ -19,7 +20,7 @@ struct int_lt { int operator()(int i1, int i2) const { return i1 < i2 ? -1 : (i1 typedef splay_tree int_splay_tree; typedef std::unordered_set int_set; -void tst0() { +static void tst0() { int_splay_tree s; s.insert(10); s.insert(11); @@ -32,7 +33,7 @@ void tst0() { s.insert(15); } -void tst1() { +static void tst1() { int_splay_tree s; s.insert(10); s.insert(3); @@ -73,7 +74,7 @@ void tst1() { lean_assert(s.empty()); } -bool operator==(int_set const & v1, int_splay_tree const & v2) { +static bool operator==(int_set const & v1, int_splay_tree const & v2) { buffer b; // std::cout << v2 << "\n"; // std::for_each(v1.begin(), v1.end(), [](int v) { std::cout << v << " "; }); std::cout << "\n"; @@ -141,9 +142,23 @@ static void tst2() { driver(128, 1000, 10000, 0.5, 0.01); } +static void tst3() { + int_splay_tree v; + v.insert(10); + v.insert(5); + v.insert(1); + v.insert(3); + lean_assert_eq(fold(v, [](int a, int b) { return a + b; }, 0), 19); + std::ostringstream out; + for_each(v, [&](int a) { out << a << " "; }); + std::cout << out.str() << "\n"; + lean_assert(out.str() == "1 3 5 10 "); +} + int main() { tst0(); tst1(); tst2(); + tst3(); return has_violations() ? 1 : 0; } diff --git a/src/util/splay_map.h b/src/util/splay_map.h index d7b317b22..a30c8ffa6 100644 --- a/src/util/splay_map.h +++ b/src/util/splay_map.h @@ -57,5 +57,37 @@ public: performing an insertion if such key does not already exist. */ ref operator[](K const & k) { return ref(*this, k); } + + template + R fold(F f, R r) const { + auto f_prime = [&](entry const & e, R r) -> R { return f(e.first, e.second, r); }; + return m_map.fold(f_prime, r); + } + + template + void for_each(F f) const { + auto f_prime = [&](entry const & e) { f(e.first, e.second); }; + return m_map.for_each(f_prime); + } }; +template +splay_map insert(splay_map const & m, K const & k, T const & v) { + auto r = m; + r.insert(k, v); + return r; +} +template +splay_map erase(splay_map const & m, K const & k) { + auto r = m; + r.erase(k); + return r; +} +template +R fold(splay_map const & m, F f, R r) { + return m.fold(f, r); +} +template +void for_each(splay_map const & m, F f) { + return m.for_each(f); +} } diff --git a/src/util/splay_tree.h b/src/util/splay_tree.h index 2a595931b..04565c380 100644 --- a/src/util/splay_tree.h +++ b/src/util/splay_tree.h @@ -287,11 +287,24 @@ class splay_tree : public CMP { } } - static unsigned size(node const * n) { - if (n) - return 1 + size(n->m_left) + size(n->m_right); - else - return 0; + template + static R fold(node const * n, F & f, R r) { + if (n) { + r = fold(n->m_left, f, r); + r = f(n->m_value, r); + return fold(n->m_right, f, r); + } else { + return r; + } + } + + template + static void for_each(node const * n, F & f) { + if (n) { + for_each(n->m_left, f); + f(n->m_value); + for_each(n->m_right, f); + } } splay_tree(splay_tree const & s, node * new_root):CMP(s), m_ptr(new_root) { node::inc_ref(m_ptr); } @@ -319,7 +332,7 @@ public: bool is_eqp(splay_tree const & t) const { return m_ptr == t.m_ptr; } /** \brief Return the size of the splay tree */ - unsigned size() const { return size(m_ptr); } + unsigned size() const { return fold([](T const &, unsigned a) { return a + 1; }, 0); } /** \brief Insert \c v in this splay tree. */ void insert(T const & v) { @@ -405,9 +418,34 @@ public: node::display(out, t.m_ptr); return out; } + + /** + \brief Return f(a_k, ..., f(a_1, f(a_0, r)) ...), where + a_0, a_1, ... a_k are the elements is stored in the splay tree. + */ + template + R fold(F f, R r) const { + return fold(m_ptr, f, r); + } + + /** + \brief Apply \c f to each value stored in the splay tree. + */ + template + void for_each(F f) const { + for_each(m_ptr, f); + } }; template splay_tree insert(splay_tree & t, T const & v) { splay_tree r(t); r.insert(v); return r; } template splay_tree erase(splay_tree & t, T const & v) { splay_tree r(t); r.erase(v); return r; } +template +R fold(splay_tree const & t, F f, R r) { + return t.fold(f, r); +} +template +void for_each(splay_tree const & t, F f) { + return t.for_each(f); +} }