Skip to content

Commit

Permalink
Add dichotomize function for splitting a sequence by a predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Jevnik committed Aug 2, 2017
1 parent c3a6294 commit 86a38b3
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Itertoolz
concatv
cons
count
dichotomize
diff
drop
first
Expand Down
1 change: 1 addition & 0 deletions toolz/curried/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
assoc_in = toolz.curry(toolz.assoc_in)
cons = toolz.curry(toolz.cons)
countby = toolz.curry(toolz.countby)
dichotomize = toolz.curry(toolz.dichotomize)
do = toolz.curry(toolz.do)
drop = toolz.curry(toolz.drop)
excepts = toolz.curry(toolz.excepts)
Expand Down
48 changes: 47 additions & 1 deletion toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv',
'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate',
'sliding_window', 'partition', 'partition_all', 'count', 'pluck',
'join', 'tail', 'diff', 'topk', 'peek', 'random_sample')
'join', 'tail', 'diff', 'topk', 'peek', 'random_sample',
'dichotomize')


def remove(predicate, seq):
Expand Down Expand Up @@ -980,3 +981,48 @@ def random_sample(prob, seq, random_state=None):
if not hasattr(random_state, 'random'):
random_state = Random(random_state)
return filter(lambda _: random_state.random() < prob, seq)


def _complement_iterator(it, predicate, our_queue, other_queue):
for element in our_queue:
yield element
our_queue.clear()

for element in it:
if predicate(element):
yield element
else:
other_queue.append(element)

for element in our_queue:
yield element
our_queue.clear()


def dichotomize(predicate, iterable):
"""Take a predicate and an iterable and return the pair of iterables of
elements which do and do not satisfy the predicate. The resulting iterators
are lazy.
>>> def even(n):
... return n & 1 == 0
...
>>> evens, odds = dichotomize(even, range(10))
>>> list(evens)
[0, 2, 4, 6, 8]
>>> list(odds)
[1, 3, 5, 7, 9]
"""
true_queue = collections.deque()
false_queue = collections.deque()
it = iter(iterable)

return (
_complement_iterator(it, predicate, true_queue, false_queue),
_complement_iterator(
it,
lambda element: not predicate(element),
false_queue,
true_queue,
),
)
29 changes: 28 additions & 1 deletion toolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
reduceby, iterate, accumulate,
sliding_window, count, partition,
partition_all, take_nth, pluck, join,
diff, topk, peek, random_sample)
diff, topk, peek, random_sample, dichotomize)
from toolz.compatibility import range, filter
from operator import add, mul

Expand Down Expand Up @@ -524,3 +524,30 @@ def test_random_sample():
assert mk_rsample(b"a") == mk_rsample(u"a")

assert raises(TypeError, lambda: mk_rsample([]))


def test_dichotimize():
evens, odds = dichotomize(iseven, range(10))
assert list(evens) == [0, 2, 4, 6, 8]
assert list(odds) == [1, 3, 5, 7, 9]


def test_dichotimize_interleaved_next_calls():
evens, odds = dichotomize(iseven, range(10))

assert next(evens) == 0
assert next(evens) == 2

assert next(odds) == 1
assert next(odds) == 3
assert next(odds) == 5

assert next(evens) == 4
assert next(evens) == 6
assert next(evens) == 8

assert next(odds) == 7
assert next(odds) == 9

assert list(evens) == []
assert list(odds) == []

0 comments on commit 86a38b3

Please sign in to comment.