def _complement_iterator(it, predicate, our_queue, other_queue):
    """Yield, in source order, the elements of ``it`` satisfying ``predicate``.

    This is one half of ``dichotomize``.  Two of these generators share the
    same underlying iterator ``it`` and the same pair of queues (with the
    queue roles swapped).  Elements pulled from ``it`` that fail ``predicate``
    are appended to ``other_queue`` for the sibling generator; elements the
    sibling has set aside for this generator are read from ``our_queue``.
    """
    while True:
        # Anything already buffered for us was pulled from the source before
        # whatever the source yields next, so it must be served first.
        # ``popleft`` is used rather than iterating the deque because the
        # sibling may append to it while this generator is suspended, and a
        # deque raises RuntimeError if mutated during iteration.
        if our_queue:
            yield our_queue.popleft()
            continue

        for element in it:
            if predicate(element):
                yield element
                # While suspended, the sibling may have advanced the source
                # and buffered elements for us; go back and check the queue
                # before pulling from the source again.
                break
            other_queue.append(element)
        else:
            # Source exhausted and nothing buffered for us: finished.
            return


def dichotomize(predicate, iterable):
    """Split an iterable into elements that do and do not satisfy a predicate

    Returns a pair of lazy iterators ``(trues, falses)``.  Each preserves the
    relative order of the input, the input is traversed only once, and the
    predicate is evaluated once per element.  The two iterators may be
    consumed in any interleaving; elements pulled from the input on behalf of
    one iterator but belonging to the other are buffered until requested.

    >>> def even(n):
    ...     return n & 1 == 0
    ...
    >>> evens, odds = dichotomize(even, range(10))
    >>> list(evens)
    [0, 2, 4, 6, 8]
    >>> list(odds)
    [1, 3, 5, 7, 9]
    """
    true_queue = collections.deque()
    false_queue = collections.deque()
    it = iter(iterable)

    return (
        _complement_iterator(it, predicate, true_queue, false_queue),
        _complement_iterator(
            it,
            lambda element: not predicate(element),
            false_queue,
            true_queue,
        ),
    )