feat(util/trie): add merge method
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
479685cb97
commit
bfa6193bfe
2 changed files with 52 additions and 0 deletions
|
@ -32,7 +32,32 @@ static void tst1() {
|
|||
lean_assert(*find(t3, "bd") == 11);
|
||||
}
|
||||
|
||||
static void tst2() {
|
||||
ctrie<int> t1;
|
||||
t1 = insert(t1, "hello", 1);
|
||||
t1 = insert(t1, "abc", 2);
|
||||
t1 = insert(t1, "hallo", 3);
|
||||
|
||||
ctrie<int> t2;
|
||||
t2 = insert(t2, "hell", 11);
|
||||
t2 = insert(t2, "abd", 12);
|
||||
t2 = insert(t2, "heaaaaaa", 13);
|
||||
t2 = insert(t2, "hallo", 14);
|
||||
t2 = insert(t2, "abe", 15);
|
||||
|
||||
t1.merge(t2);
|
||||
|
||||
lean_assert(*find(t1, "hallo") == 14);
|
||||
lean_assert(*find(t1, "hello") == 1);
|
||||
lean_assert(*find(t1, "heaaaaaa") == 13);
|
||||
lean_assert(*find(t1, "abc") == 2);
|
||||
lean_assert(*find(t1, "abd") == 12);
|
||||
lean_assert(!find(t2, "abc"));
|
||||
lean_assert(*find(t2, "abd") == 12);
|
||||
}
|
||||
|
||||
int main() {
|
||||
tst1();
|
||||
tst2();
|
||||
return has_violations() ? 1 : 0;
|
||||
}
|
||||
|
|
|
@ -67,6 +67,22 @@ class trie : public KeyCMP {
|
|||
}
|
||||
}
|
||||
|
||||
static node merge(node && t1, node const & t2) {
|
||||
node new_t1 = ensure_unshared(t1.steal());
|
||||
new_t1->m_value = t2->m_value;
|
||||
t2->m_children.for_each([&](Key const & k, node const & c2) {
|
||||
node const * c1 = new_t1->m_children.find(k);
|
||||
if (c1 == nullptr) {
|
||||
new_t1->m_children.insert(k, c2);
|
||||
} else {
|
||||
node n1(*c1);
|
||||
new_t1->m_children.erase(k);
|
||||
new_t1->m_children.insert(k, merge(n1.steal(), c2));
|
||||
}
|
||||
});
|
||||
return new_t1;
|
||||
}
|
||||
|
||||
node m_root;
|
||||
trie(node const & n):m_root(n) {}
|
||||
public:
|
||||
|
@ -100,6 +116,10 @@ public:
|
|||
else
|
||||
return optional<trie>();
|
||||
}
|
||||
|
||||
void merge(trie const & t) {
|
||||
m_root = merge(m_root.steal(), t.m_root);
|
||||
}
|
||||
};
|
||||
|
||||
struct char_cmp { int operator()(char c1, char c2) const { return c1 < c2 ? -1 : (c1 == c2 ? 0 : 1); } };
|
||||
|
@ -126,4 +146,11 @@ optional<Val> find(ctrie<Val> const & t, std::string const & str) { return t.fin
|
|||
|
||||
template<typename Val>
|
||||
optional<Val> find(ctrie<Val> const & t, char const * str) { return t.find(str, str + strlen(str)); }
|
||||
|
||||
template<typename Val>
|
||||
inline ctrie<Val> merge(ctrie<Val> const & t1, ctrie<Val> const & t2) {
|
||||
ctrie<Val> r(t1);
|
||||
r.merge(t2);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue