feat(util/trie): add merge method
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
479685cb97
commit
bfa6193bfe
2 changed files with 52 additions and 0 deletions
|
@ -32,7 +32,32 @@ static void tst1() {
|
||||||
lean_assert(*find(t3, "bd") == 11);
|
lean_assert(*find(t3, "bd") == 11);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void tst2() {
|
||||||
|
ctrie<int> t1;
|
||||||
|
t1 = insert(t1, "hello", 1);
|
||||||
|
t1 = insert(t1, "abc", 2);
|
||||||
|
t1 = insert(t1, "hallo", 3);
|
||||||
|
|
||||||
|
ctrie<int> t2;
|
||||||
|
t2 = insert(t2, "hell", 11);
|
||||||
|
t2 = insert(t2, "abd", 12);
|
||||||
|
t2 = insert(t2, "heaaaaaa", 13);
|
||||||
|
t2 = insert(t2, "hallo", 14);
|
||||||
|
t2 = insert(t2, "abe", 15);
|
||||||
|
|
||||||
|
t1.merge(t2);
|
||||||
|
|
||||||
|
lean_assert(*find(t1, "hallo") == 14);
|
||||||
|
lean_assert(*find(t1, "hello") == 1);
|
||||||
|
lean_assert(*find(t1, "heaaaaaa") == 13);
|
||||||
|
lean_assert(*find(t1, "abc") == 2);
|
||||||
|
lean_assert(*find(t1, "abd") == 12);
|
||||||
|
lean_assert(!find(t2, "abc"));
|
||||||
|
lean_assert(*find(t2, "abd") == 12);
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
tst1();
|
tst1();
|
||||||
|
tst2();
|
||||||
return has_violations() ? 1 : 0;
|
return has_violations() ? 1 : 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,6 +67,22 @@ class trie : public KeyCMP {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static node merge(node && t1, node const & t2) {
|
||||||
|
node new_t1 = ensure_unshared(t1.steal());
|
||||||
|
new_t1->m_value = t2->m_value;
|
||||||
|
t2->m_children.for_each([&](Key const & k, node const & c2) {
|
||||||
|
node const * c1 = new_t1->m_children.find(k);
|
||||||
|
if (c1 == nullptr) {
|
||||||
|
new_t1->m_children.insert(k, c2);
|
||||||
|
} else {
|
||||||
|
node n1(*c1);
|
||||||
|
new_t1->m_children.erase(k);
|
||||||
|
new_t1->m_children.insert(k, merge(n1.steal(), c2));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return new_t1;
|
||||||
|
}
|
||||||
|
|
||||||
node m_root;
|
node m_root;
|
||||||
trie(node const & n):m_root(n) {}
|
trie(node const & n):m_root(n) {}
|
||||||
public:
|
public:
|
||||||
|
@ -100,6 +116,10 @@ public:
|
||||||
else
|
else
|
||||||
return optional<trie>();
|
return optional<trie>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void merge(trie const & t) {
|
||||||
|
m_root = merge(m_root.steal(), t.m_root);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct char_cmp { int operator()(char c1, char c2) const { return c1 < c2 ? -1 : (c1 == c2 ? 0 : 1); } };
|
struct char_cmp { int operator()(char c1, char c2) const { return c1 < c2 ? -1 : (c1 == c2 ? 0 : 1); } };
|
||||||
|
@ -126,4 +146,11 @@ optional<Val> find(ctrie<Val> const & t, std::string const & str) { return t.fin
|
||||||
|
|
||||||
template<typename Val>
|
template<typename Val>
|
||||||
optional<Val> find(ctrie<Val> const & t, char const * str) { return t.find(str, str + strlen(str)); }
|
optional<Val> find(ctrie<Val> const & t, char const * str) { return t.find(str, str + strlen(str)); }
|
||||||
|
|
||||||
|
template<typename Val>
|
||||||
|
inline ctrie<Val> merge(ctrie<Val> const & t1, ctrie<Val> const & t2) {
|
||||||
|
ctrie<Val> r(t1);
|
||||||
|
r.merge(t2);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue