From c1fd6da1402ea80baee6308fdfc684a2c935fcd9 Mon Sep 17 00:00:00 2001
From: WitoldFracek
Date: Tue, 12 Nov 2024 10:37:33 +0100
Subject: [PATCH] batch_by implemented

---
 src/qwlist/eager.py  | 37 +++++++++++++++++++++
 src/qwlist/qwlist.py | 79 +++++++++++++++++++++++++++++++++++++++++++-
 src/todo.md          |  6 ++--
 tests/test_eager.py  | 22 ++++++++++++
 tests/test_lazy.py   | 24 +++++++++++++-
 tests/test_qwlist.py | 22 ++++++++++++
 6 files changed, 185 insertions(+), 5 deletions(-)

diff --git a/src/qwlist/eager.py b/src/qwlist/eager.py
index 5bf99e7..cb19c73 100644
--- a/src/qwlist/eager.py
+++ b/src/qwlist/eager.py
@@ -6,6 +6,7 @@
 K = TypeVar('K')
 SupportsLessThan = TypeVar("SupportsLessThan")
 SupportsAdd = TypeVar("SupportsAdd")
+SupportsEq = TypeVar("SupportsEq")
 Booly = TypeVar('Booly')
 
 
@@ -242,6 +243,42 @@ def inner():
 
         return EagerQList(inner())
 
+    def batch_by(self, grouper: Callable[[T], SupportsEq]) -> "EagerQList[EagerQList[T]]":
+        """
+        Batches elements of `self` based on the output of the grouper function. Elements are placed
+        in the same batch as long as the grouper function returns the same key (keys must support equality checks).
+        When a new key is returned a new batch (group) is created.
+
+        Args:
+            grouper (Callable[[T], SupportsEq]): function `(T) -> SupportsEq` that provides the keys
+                used to group elements, where the key type must support equality comparisons.
+
+        Returns:
+            `EagerQList[EagerQList[T]]`
+
+        Examples:
+            >>> EagerQList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0])
+            [['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']]
+            >>> EagerQList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1])
+            [['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']]
+        """
+        def inner():
+            if self.len() == 0:
+                return
+            batch = EagerQList([self[0]])
+            key = grouper(self[0])
+            for elem in self[1:]:
+                new_key = grouper(elem)
+                if new_key == key:
+                    batch.append(elem)
+                else:
+                    yield batch
+                    batch = EagerQList([elem])
+                    key = new_key
+            # the trailing batch always holds at least one element, so emit it unconditionally
+            yield batch
+        return EagerQList(inner())
+
     def chain(self, other: Iterable[T]) -> "EagerQList[T]":
         """
         Chains `self` with `other`, returning a new EagerQList with all elements from both iterables.
diff --git a/src/qwlist/qwlist.py b/src/qwlist/qwlist.py
index 9e09325..fd68c11 100644
--- a/src/qwlist/qwlist.py
+++ b/src/qwlist/qwlist.py
@@ -4,6 +4,7 @@
 K = TypeVar('K')
 SupportsLessThan = TypeVar("SupportsLessThan")
 SupportsAdd = TypeVar("SupportsAdd")
+SupportsEq = TypeVar("SupportsEq")
 Booly = TypeVar('Booly')
 
 
@@ -309,6 +310,45 @@ def inner():
                 yield group
         return Lazy(inner())
 
+    def batch_by(self, grouper: Callable[[T], SupportsEq]) -> "Lazy[QList[T]]":
+        """
+        Batches elements of `self` based on the output of the grouper function. Elements are placed
+        in the same batch as long as the grouper function returns the same key (keys must support equality checks).
+        When a new key is returned a new batch (group) is created.
+
+        Args:
+            grouper (Callable[[T], SupportsEq]): function `(T) -> SupportsEq` that provides the keys
+                used to group elements, where the key type must support equality comparisons.
+
+        Returns:
+            `Lazy[QList[T]]`
+
+        Examples:
+            >>> Lazy(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect()
+            [['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']]
+            >>> Lazy(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1]).collect()
+            [['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']]
+        """
+        def inner():
+            it = self.iter()
+            try:
+                first = next(it)
+            except StopIteration:
+                return
+            batch = QList([first])
+            key = grouper(first)
+            for elem in it:
+                new_key = grouper(elem)
+                if new_key == key:
+                    batch.append(elem)
+                else:
+                    yield batch
+                    batch = QList([elem])
+                    key = new_key
+            # the trailing batch always holds at least one element, so emit it unconditionally
+            yield batch
+        return Lazy(inner())
+
     def chain(self, other: Iterable[T]) -> "Lazy[T]":
         """
         Chains `self` with `other`, returning a new Lazy[T] with all elements from both iterables.
@@ -840,6 +880,42 @@ def inner():
                 yield QList(self[i:i+size])
         return Lazy(inner())
 
+    def batch_by(self, grouper: Callable[[T], SupportsEq]) -> "Lazy[QList[T]]":
+        """
+        Batches elements of `self` based on the output of the grouper function. Elements are placed
+        in the same batch as long as the grouper function returns the same key (keys must support equality checks).
+        When a new key is returned a new batch (group) is created.
+
+        Args:
+            grouper (Callable[[T], SupportsEq]): function `(T) -> SupportsEq` that provides the keys
+                used to group elements, where the key type must support equality comparisons.
+
+        Returns:
+            `Lazy[QList[T]]`
+
+        Examples:
+            >>> QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect()
+            [['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']]
+            >>> QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1]).collect()
+            [['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']]
+        """
+        def inner():
+            if self.len() == 0:
+                return
+            batch = QList([self[0]])
+            key = grouper(self[0])
+            for elem in self[1:]:
+                new_key = grouper(elem)
+                if new_key == key:
+                    batch.append(elem)
+                else:
+                    yield batch
+                    batch = QList([elem])
+                    key = new_key
+            # the trailing batch always holds at least one element, so emit it unconditionally
+            yield batch
+        return Lazy(inner())
+
     def chain(self, other: Iterable[T]) -> Lazy[T]:
         """
         Chains `self` with `other`, returning a Lazy[T] with all elements from both iterables.
@@ -1050,5 +1126,6 @@ def naturals(start):
                 .all(lambda x: n % x != 0)
             ))
     )
-    print(Lazy(['a', 'b', 1]).sum())
+    print(QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect())
+    print(Lazy([]).batch_by(str).collect())
 
diff --git a/src/todo.md b/src/todo.md
index 190003c..bfc9ce6 100644
--- a/src/todo.md
+++ b/src/todo.md
@@ -55,9 +55,9 @@ QList, Lazy, EagerQList
 [x] [x] [x] tested
 
 ### batch_by
-[ ] [ ] [ ] implemented \
-[ ] [ ] [ ] documented \
-[ ] [ ] [ ] tested
+[x] [x] [x] implemented \
+[x] [x] [x] documented \
+[x] [x] [x] tested
 
 ### window
 [ ] [ ] [ ] implemented \
diff --git a/tests/test_eager.py b/tests/test_eager.py
index 39cbd51..acac3c9 100644
--- a/tests/test_eager.py
+++ b/tests/test_eager.py
@@ -340,3 +340,25 @@ def test_sum():
     assert EagerQList([1]).sum() == 1
     assert EagerQList().sum() is None
     assert EagerQList(range(4)).fold(lambda acc, x: acc + x, 0) == EagerQList(range(4)).sum()
+
+
+def test_batch_by():
+    expected = EagerQList()
+    res = EagerQList().batch_by(int)
+    assert res == expected
+
+    expected = EagerQList([[0], [1], [2]])
+    res = EagerQList(range(3)).batch_by(lambda x: x)
+    assert res == expected
+
+    expected = EagerQList([['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']])
+    res = EagerQList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0])
+    assert res == expected
+
+    expected = EagerQList([['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']])
+    res = EagerQList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1])
+    assert res == expected
+
+    expected = EagerQList([[0, 1, 2, 3]])
+    res = EagerQList(range(4)).batch_by(lambda x: True)
+    assert res == expected
diff --git a/tests/test_lazy.py b/tests/test_lazy.py
index c7b3957..f10ed4e 100644
--- a/tests/test_lazy.py
+++ b/tests/test_lazy.py
@@ -408,4 +408,26 @@ def test_sum():
     assert Lazy([1]).sum() == 1
     assert Lazy([]).sum() is None
     assert Lazy(range(3)).map(str).sum() == '012'
-    assert Lazy(range(4)).fold(lambda acc, x: acc + x, 0) == Lazy(range(4)).sum()
\ No newline at end of file
+    assert Lazy(range(4)).fold(lambda acc, x: acc + x, 0) == Lazy(range(4)).sum()
+
+
+def test_batch_by():
+    expected = QList()
+    res = Lazy([]).batch_by(int).collect()
+    assert res == expected
+
+    expected = QList([[0], [1], [2]])
+    res = Lazy(range(3)).batch_by(lambda x: x).collect()
+    assert res == expected
+
+    expected = QList([['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']])
+    res = Lazy(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect()
+    assert res == expected
+
+    expected = QList([['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']])
+    res = Lazy(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1]).collect()
+    assert res == expected
+
+    expected = QList([[0, 1, 2, 3]])
+    res = Lazy(range(4)).batch_by(lambda x: True).collect()
+    assert res == expected
diff --git a/tests/test_qwlist.py b/tests/test_qwlist.py
index 31fbbfa..7ca8a08 100644
--- a/tests/test_qwlist.py
+++ b/tests/test_qwlist.py
@@ -449,3 +449,25 @@ def test_sum():
     assert QList().sum() is None
     assert QList(range(4)).fold(lambda acc, x: acc + x, 0) == QList(range(4)).sum()
 
+
+def test_batch_by():
+    expected = QList()
+    res = QList().batch_by(int).collect()
+    assert res == expected
+
+    expected = QList([[0], [1], [2]])
+    res = QList(range(3)).batch_by(lambda x: x).collect()
+    assert res == expected
+
+    expected = QList([['a1'], ['b1', 'b2'], ['a2', 'a3'], ['b3']])
+    res = QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[0]).collect()
+    assert res == expected
+
+    expected = QList([['a1', 'b1'], ['b2', 'a2'], ['a3', 'b3']])
+    res = QList(['a1', 'b1', 'b2', 'a2', 'a3', 'b3']).batch_by(lambda s: s[1]).collect()
+    assert res == expected
+
+    expected = QList([[0, 1, 2, 3]])
+    res = QList(range(4)).batch_by(lambda x: True).collect()
+    assert res == expected
+