feat(util/trie): add merge method

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-05-13 17:59:20 -07:00
parent 479685cb97
commit bfa6193bfe
2 changed files with 52 additions and 0 deletions

View file

@ -32,7 +32,32 @@ static void tst1() {
lean_assert(*find(t3, "bd") == 11); lean_assert(*find(t3, "bd") == 11);
} }
static void tst2() {
ctrie<int> t1;
t1 = insert(t1, "hello", 1);
t1 = insert(t1, "abc", 2);
t1 = insert(t1, "hallo", 3);
ctrie<int> t2;
t2 = insert(t2, "hell", 11);
t2 = insert(t2, "abd", 12);
t2 = insert(t2, "heaaaaaa", 13);
t2 = insert(t2, "hallo", 14);
t2 = insert(t2, "abe", 15);
t1.merge(t2);
lean_assert(*find(t1, "hallo") == 14);
lean_assert(*find(t1, "hello") == 1);
lean_assert(*find(t1, "heaaaaaa") == 13);
lean_assert(*find(t1, "abc") == 2);
lean_assert(*find(t1, "abd") == 12);
lean_assert(!find(t2, "abc"));
lean_assert(*find(t2, "abd") == 12);
}
int main() { int main() {
tst1(); tst1();
tst2();
return has_violations() ? 1 : 0; return has_violations() ? 1 : 0;
} }

View file

@ -67,6 +67,22 @@ class trie : public KeyCMP {
} }
} }
static node merge(node && t1, node const & t2) {
node new_t1 = ensure_unshared(t1.steal());
new_t1->m_value = t2->m_value;
t2->m_children.for_each([&](Key const & k, node const & c2) {
node const * c1 = new_t1->m_children.find(k);
if (c1 == nullptr) {
new_t1->m_children.insert(k, c2);
} else {
node n1(*c1);
new_t1->m_children.erase(k);
new_t1->m_children.insert(k, merge(n1.steal(), c2));
}
});
return new_t1;
}
node m_root; node m_root;
trie(node const & n):m_root(n) {} trie(node const & n):m_root(n) {}
public: public:
@ -100,6 +116,10 @@ public:
else else
return optional<trie>(); return optional<trie>();
} }
void merge(trie const & t) {
m_root = merge(m_root.steal(), t.m_root);
}
}; };
struct char_cmp { int operator()(char c1, char c2) const { return c1 < c2 ? -1 : (c1 == c2 ? 0 : 1); } }; struct char_cmp { int operator()(char c1, char c2) const { return c1 < c2 ? -1 : (c1 == c2 ? 0 : 1); } };
@ -126,4 +146,11 @@ optional<Val> find(ctrie<Val> const & t, std::string const & str) { return t.fin
template<typename Val> template<typename Val>
optional<Val> find(ctrie<Val> const & t, char const * str) { return t.find(str, str + strlen(str)); } optional<Val> find(ctrie<Val> const & t, char const * str) { return t.find(str, str + strlen(str)); }
template<typename Val>
inline ctrie<Val> merge(ctrie<Val> const & t1, ctrie<Val> const & t2) {
ctrie<Val> r(t1);
r.merge(t2);
return r;
}
} }