diff --git a/src/tests/util/lazy_list.cpp b/src/tests/util/lazy_list.cpp index 19e8e549f..ab76ccc2d 100644 --- a/src/tests/util/lazy_list.cpp +++ b/src/tests/util/lazy_list.cpp @@ -60,7 +60,6 @@ lazy_list loop() { }); } - template void display(lazy_list const & l) { int buffer[20000]; @@ -116,8 +115,63 @@ void tst2() { lean_assert(i == 2); } +static void check(lazy_list const & l, list expected) { + display(l); + for_each(l, [&](int v) { + lean_assert(expected); + lean_assert(v == head(expected)); + expected = tail(expected); + }); + lean_assert(!expected); +} + +static void tst3() { + check(repeat(1, [](int v) { if (v > 5) return lazy_list(); else return lazy_list(v+1); }), + list(6)); + // The following repeat produces the following "execution trace". + // We use >> << to mark the element being processed + // { >>1<< } 1 produces {2, 4} + // { >>2<<, 4 } 2 produces {3, 6} + // { >>3<<, 6, 4 } 3 produces {4, 8} + // { >>4<<, 8, 6, 4 } 4 produces {5, 10} + // { >>5<<, 10, 8, 6, 4 } 5 produces {6, 12} + // { 6, 12, 10, 8, 6, >>4<< } skips 6, 12, 10, 8, 6 since they are bigger than 5, and 4 produces {5, 10} + // { 6, 12, 10, 8, 6, >>5<<, 10 } 5 produces {6, 12} + // { 6, 12, 10, 8, 6, 6, 12, 10 } skips 6, 12, 10 since they are bigger than 5 + check(repeat(1, [](int v) { if (v > 5) return lazy_list(); else return from(v+1, v+1, 2*(v + 1)); }), + list({6, 12, 10, 8, 6, 6, 12, 10})); +} + +static void tst4() { + // We use v:k to denote value v was produced with the given k. + // When k == 0, the element is not processed. + // Here is the execution trace for the following repeat_at_most + // { 1:4 } + // { 1:3, 2:3 } + // { 1:2, 2:2, 2:3 } + // { 1:1, 2:1, 2:2, 2:3 } + // { 1:0, 2:0, 2:1, 2:2, 2:3 } + // { 1:0, 2:0, 2:0, 4:0, 2:2, 2:3 } + // { 1:0, 2:0, 2:0, 4:0, 2:1, 4:1, 2:3 } + // { 1:0, 2:0, 2:0, 4:0, 2:0, 4:0, 4:1, 2:3 } + // { 1:0, 2:0, 2:0, 4:0, 2:0, 4:0, 4:0, 8:0, 2:3 } + // { 1:0, 2:0, 2:0, 4:0, 2:0, 4:0, 4:0, 8:0, 2:2, 4:2 } + // { 1:0, 2:0, 2:0, 4:0, 2:0, 4:0, 4:0, 8:0, 2:1, 4:1, 4:2 } + // { 1:0, 2:0, 2:0, 4:0, 2:0, 4:0, 4:0, 8:0, 2:0, 4:0, 4:1, 4:2 } + // { 1:0, 2:0, 2:0, 4:0, 2:0, 4:0, 4:0, 8:0, 2:0, 4:0, 4:0, 8:0, 4:2 } + // { 1:0, 2:0, 2:0, 4:0, 2:0, 4:0, 4:0, 8:0, 2:0, 4:0, 4:0, 8:0, 4:1, 8:1 } + // { 1:0, 2:0, 2:0, 4:0, 2:0, 4:0, 4:0, 8:0, 2:0, 4:0, 4:0, 8:0, 4:0, 8:0, 8:1 } + // { 1:0, 2:0, 2:0, 4:0, 2:0, 4:0, 4:0, 8:0, 2:0, 4:0, 4:0, 8:0, 4:0, 8:0, 8:0, 16:0 } + // Thus, the final lazy list is + // { 1, 2, 2, 4, 2, 4, 4, 8, 2, 4, 4, 8, 4, 8, 8, 16 } + check(repeat_at_most(1, [](int v) { return from(v, v, 2*v); }, 4), + list({ 1, 2, 2, 4, 2, 4, 4, 8, 2, 4, 4, 8, 4, 8, 8, 16 })); +} + int main() { tst1(); tst2(); + tst3(); + tst4(); return has_violations() ? 1 : 0; } diff --git a/src/util/lazy_list_fn.h b/src/util/lazy_list_fn.h index 62af156db..c08c38fec 100644 --- a/src/util/lazy_list_fn.h +++ b/src/util/lazy_list_fn.h @@ -173,6 +173,38 @@ lazy_list map_append(lazy_list const & l, F && f) { return map_append_aux(lazy_list(), l, f); } +template +lazy_list repeat(T const & v, F && f) { + return mk_lazy_list([=]() { + auto p = f(v).pull(); + if (!p) { + return some(mk_pair(v, lazy_list())); + } else { + check_interrupted(); + return append(repeat(p->first, f), + map_append(p->second, [=](T const & v2) { return repeat(v2, f); })).pull(); + } + }); +} + +template +lazy_list repeat_at_most(T const & v, F && f, unsigned k) { + return mk_lazy_list([=]() { + if (k == 0) { + return some(mk_pair(v, lazy_list())); + } else { + auto p = f(v).pull(); + if (!p) { + return some(mk_pair(v, lazy_list())); + } else { + check_interrupted(); + return append(repeat_at_most(p->first, f, k - 1), + map_append(p->second, [=](T const & v2) { return repeat_at_most(v2, f, k - 1); })).pull(); + } + } + }); +} + /** \brief Return a lazy list such that only the elements that can be computed in less than \c ms milliseconds are kept. That is, it uses a timeout for the \c pull