Skip to content

Commit

Permalink
batch_by implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
WitoldFracek committed Nov 12, 2024
1 parent 1de62d1 commit c1fd6da
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 5 deletions.
37 changes: 37 additions & 0 deletions src/qwlist/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
K = TypeVar('K')
SupportsLessThan = TypeVar("SupportsLessThan")
SupportsAdd = TypeVar("SupportsAdd")
SupportsEq = TypeVar("SupportsEq")
Booly = TypeVar('Booly')


Expand Down Expand Up @@ -242,6 +243,42 @@ def inner():

return EagerQList(inner())

def batch_by(self, grouper: Callable[[T], SupportsEq]) -> "EagerQList[EagerQList[T]]":
"""
Batches elements of `self` based on the output of the grouper function. Elements are thrown
to the same group as long as the grouper function returns the same key (keys must support equality checks).
When a new key is returned a new batch (group) is created.
Args:
grouper (Callable[[T], SupportsEq]): function `(T) -> SupportsEq` that provides the keys
used to group elements, where the key type must support equality comparisons.
Returns:
`EagerQList[EagerQList[T]]`
Examples:
>>> EagerQList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0])
[['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']]
>>> EagerQList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1])
[['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']]
"""
def inner():
if self.len() == 0:
return
batch = EagerQList([self[0]])
key = grouper(self[0])
for elem in self[1:]:
new_key = grouper(elem)
if new_key == key:
batch.append(elem)
else:
yield batch
batch = EagerQList([elem])
key = new_key
if batch:
yield batch
return EagerQList(inner())

def chain(self, other: Iterable[T]) -> "EagerQList[T]":
"""
Chains `self` with `other`, returning a new EagerQList with all elements from both iterables.
Expand Down
79 changes: 78 additions & 1 deletion src/qwlist/qwlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
K = TypeVar('K')
SupportsLessThan = TypeVar("SupportsLessThan")
SupportsAdd = TypeVar("SupportsAdd")
SupportsEq = TypeVar("SupportsEq")
Booly = TypeVar('Booly')


Expand Down Expand Up @@ -309,6 +310,45 @@ def inner():
yield group
return Lazy(inner())

def batch_by(self, grouper: Callable[[T], SupportsEq]) -> "Lazy[QList[T]]":
"""
Batches elements of `self` based on the output of the grouper function. Elements are thrown
to the same group as long as the grouper function returns the same key (keys must support equality checks).
When a new key is returned a new batch (group) is created.
Args:
grouper (Callable[[T], SupportsEq]): function `(T) -> SupportsEq` that provides the keys
used to group elements, where the key type must support equality comparisons.
Returns:
`Lazy[QList[T]]`
Examples:
>>> Lazy(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect()
[['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']]
>>> Lazy(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1]).collect()
[['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']]
"""
def inner():
it = self.iter()
try:
first = next(it)
except StopIteration:
return
batch = QList([first])
key = grouper(first)
for elem in it:
new_key = grouper(elem)
if new_key == key:
batch.append(elem)
else:
yield batch
batch = QList([elem])
key = new_key
if batch:
yield batch
return Lazy(inner())

def chain(self, other: Iterable[T]) -> "Lazy[T]":
"""
Chains `self` with `other`, returning a new Lazy[T] with all elements from both iterables.
Expand Down Expand Up @@ -840,6 +880,42 @@ def inner():
yield QList(self[i:i+size])
return Lazy(inner())

def batch_by(self, grouper: Callable[[T], SupportsEq]) -> "Lazy[QList[T]]":
"""
Batches elements of `self` based on the output of the grouper function. Elements are thrown
to the same group as long as the grouper function returns the same key (keys must support equality checks).
When a new key is returned a new batch (group) is created.
Args:
grouper (Callable[[T], SupportsEq]): function `(T) -> SupportsEq` that provides the keys
used to group elements, where the key type must support equality comparisons.
Returns:
`Lazy[QList[T]]`
Examples:
>>> QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect()
[['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']]
>>> QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1]).collect()
[['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']]
"""
def inner():
if self.len() == 0:
return
batch = QList([self[0]])
key = grouper(self[0])
for elem in self[1:]:
new_key = grouper(elem)
if new_key == key:
batch.append(elem)
else:
yield batch
batch = QList([elem])
key = new_key
if batch:
yield batch
return Lazy(inner())

def chain(self, other: Iterable[T]) -> Lazy[T]:
"""
Chains `self` with `other`, returning a Lazy[T] with all elements from both iterables.
Expand Down Expand Up @@ -1050,5 +1126,6 @@ def naturals(start):
.all(lambda x: n % x != 0)
))
)
print(Lazy(['a', 'b', 1]).sum())
print(QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect())
print(Lazy([]).batch_by(str).collect())

6 changes: 3 additions & 3 deletions src/todo.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ QList, Lazy, EagerQList
[x] [x] [x] tested

### batch_by
[ ] [ ] [ ] implemented \
[ ] [ ] [ ] documented \
[ ] [ ] [ ] tested
[x] [x] [x] implemented \
[x] [x] [x] documented \
[x] [x] [x] tested

### window
[ ] [ ] [ ] implemented \
Expand Down
22 changes: 22 additions & 0 deletions tests/test_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,25 @@ def test_sum():
assert EagerQList([1]).sum() == 1
assert EagerQList().sum() is None
assert EagerQList(range(4)).fold(lambda acc, x: acc + x, 0) == EagerQList(range(4)).sum()


def test_batch_by():
expected = EagerQList()
res = EagerQList().batch_by(int)
assert res == expected

expected = EagerQList([[0], [1], [2]])
res= EagerQList(range(3)).batch_by(lambda x: x)
assert res == expected

expected = EagerQList([['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']])
res = EagerQList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0])
assert res == expected

expected = EagerQList([['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']])
res = EagerQList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1])
assert res == expected

expected = EagerQList([0, 1, 2, 3])
res = EagerQList(range(4)).batch_by(lambda x: True)
assert res == expected
24 changes: 23 additions & 1 deletion tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,4 +408,26 @@ def test_sum():
assert Lazy([1]).sum() == 1
assert Lazy([]).sum() is None
assert Lazy(range(3)).map(str).sum() == '012'
assert Lazy(range(4)).fold(lambda acc, x: acc + x, 0) == Lazy(range(4)).sum()
assert Lazy(range(4)).fold(lambda acc, x: acc + x, 0) == Lazy(range(4)).sum()


def test_batch_by():
expected = QList()
res = Lazy([]).batch_by(int).collect()
assert res == expected

expected = QList([[0], [1], [2]])
res = Lazy(range(3)).batch_by(lambda x: x).collect()
assert res == expected

expected = QList([['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']])
res = Lazy(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect()
assert res == expected

expected = QList([['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']])
res = Lazy(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1]).collect()
assert res == expected

expected = QList([0, 1, 2, 3])
res = Lazy(range(4)).batch_by(lambda x: True).collect()
assert res == expected
22 changes: 22 additions & 0 deletions tests/test_qwlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,25 @@ def test_sum():
assert QList().sum() is None
assert QList(range(4)).fold(lambda acc, x: acc + x, 0) == QList(range(4)).sum()


def test_batch_by():
expected = QList()
res = QList().batch_by(int).collect()
assert res == expected

expected = QList([[0], [1], [2]])
res = QList(range(3)).batch_by(lambda x: x).collect()
assert res == expected

expected = QList([['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']])
res = QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect()
assert res == expected

expected = QList([['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']])
res = QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1]).collect()
assert res == expected

expected = QList([0, 1, 2, 3])
res = QList(range(4)).batch_by(lambda x: True).collect()
assert res == expected

0 comments on commit c1fd6da

Please sign in to comment.