From bfa6193bfe1ce2cceef0dbf0a15b443141768b94 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 13 May 2014 17:59:20 -0700 Subject: [PATCH] feat(util/trie): add merge method Signed-off-by: Leonardo de Moura --- src/tests/util/trie.cpp | 25 +++++++++++++++++++++++++ src/util/trie.h | 27 +++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/src/tests/util/trie.cpp b/src/tests/util/trie.cpp index be4a90e92..671ea0329 100644 --- a/src/tests/util/trie.cpp +++ b/src/tests/util/trie.cpp @@ -32,7 +32,32 @@ static void tst1() { lean_assert(*find(t3, "bd") == 11); } +static void tst2() { + ctrie t1; + t1 = insert(t1, "hello", 1); + t1 = insert(t1, "abc", 2); + t1 = insert(t1, "hallo", 3); + + ctrie t2; + t2 = insert(t2, "hell", 11); + t2 = insert(t2, "abd", 12); + t2 = insert(t2, "heaaaaaa", 13); + t2 = insert(t2, "hallo", 14); + t2 = insert(t2, "abe", 15); + + t1.merge(t2); + + lean_assert(*find(t1, "hallo") == 14); + lean_assert(*find(t1, "hello") == 1); + lean_assert(*find(t1, "heaaaaaa") == 13); + lean_assert(*find(t1, "abc") == 2); + lean_assert(*find(t1, "abd") == 12); + lean_assert(!find(t2, "abc")); + lean_assert(*find(t2, "abd") == 12); +} + int main() { tst1(); + tst2(); return has_violations() ? 1 : 0; } diff --git a/src/util/trie.h b/src/util/trie.h index 006dab3ec..e6397e321 100644 --- a/src/util/trie.h +++ b/src/util/trie.h @@ -67,6 +67,22 @@ class trie : public KeyCMP { } } + static node merge(node && t1, node const & t2) { + node new_t1 = ensure_unshared(t1.steal()); + new_t1->m_value = t2->m_value; + t2->m_children.for_each([&](Key const & k, node const & c2) { + node const * c1 = new_t1->m_children.find(k); + if (c1 == nullptr) { + new_t1->m_children.insert(k, c2); + } else { + node n1(*c1); + new_t1->m_children.erase(k); + new_t1->m_children.insert(k, merge(n1.steal(), c2)); + } + }); + return new_t1; + } + node m_root; trie(node const & n):m_root(n) {} public: @@ -100,6 +116,10 @@ public: else return optional(); } + + void merge(trie const & t) { + m_root = merge(m_root.steal(), t.m_root); + } }; struct char_cmp { int operator()(char c1, char c2) const { return c1 < c2 ? -1 : (c1 == c2 ? 0 : 1); } }; @@ -126,4 +146,11 @@ optional find(ctrie const & t, std::string const & str) { return t.fin template optional find(ctrie const & t, char const * str) { return t.find(str, str + strlen(str)); } + +template +inline ctrie merge(ctrie const & t1, ctrie const & t2) { + ctrie r(t1); + r.merge(t2); + return r; +} }