Skip to content

Commit

Permalink
pythongh-126317: Simplify stdlib code by using itertools.batched()
Browse files Browse the repository at this point in the history
  • Loading branch information
dongwooklee96 committed Nov 2, 2024
1 parent f0c6fcc commit 11eb1ee
Showing 1 changed file with 10 additions and 21 deletions.
31 changes: 10 additions & 21 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from types import FunctionType
from copyreg import dispatch_table
from copyreg import _extension_registry, _inverted_registry, _extension_cache
from itertools import islice
from itertools import batched
from functools import partial
import sys
from sys import maxsize
Expand Down Expand Up @@ -1035,12 +1035,11 @@ def _batch_appends(self, items, obj):

it = iter(items)
start = 0
while True:
tmp = list(islice(it, self._BATCHSIZE))
n = len(tmp)
for batch in batched(it, self._BATCHSIZE):
n = len(batch)
if n > 1:
write(MARK)
for i, x in enumerate(tmp, start):
for i, x in enumerate(batch, start):
try:
save(x)
except BaseException as exc:
Expand All @@ -1049,14 +1048,11 @@ def _batch_appends(self, items, obj):
write(APPENDS)
elif n:
try:
save(tmp[0])
save(batch[0])
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} item {start}')
raise
write(APPEND)
# else tmp is empty, and we're done
if n < self._BATCHSIZE:
return
start += n

def save_dict(self, obj):
Expand Down Expand Up @@ -1087,12 +1083,11 @@ def _batch_setitems(self, items, obj):
return

it = iter(items)
while True:
tmp = list(islice(it, self._BATCHSIZE))
n = len(tmp)
for batch in batched(it, self._BATCHSIZE):
n = len(batch)
if n > 1:
write(MARK)
for k, v in tmp:
for k, v in batch:
save(k)
try:
save(v)
Expand All @@ -1101,17 +1096,14 @@ def _batch_setitems(self, items, obj):
raise
write(SETITEMS)
elif n:
k, v = tmp[0]
k, v = batch[0]
save(k)
try:
save(v)
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} item {k!r}')
raise
write(SETITEM)
# else tmp is empty, and we're done
if n < self._BATCHSIZE:
return

def save_set(self, obj):
save = self.save
Expand All @@ -1125,8 +1117,7 @@ def save_set(self, obj):
self.memoize(obj)

it = iter(obj)
while True:
batch = list(islice(it, self._BATCHSIZE))
for batch in batched(it, self._BATCHSIZE):
n = len(batch)
if n > 0:
write(MARK)
Expand All @@ -1137,8 +1128,6 @@ def save_set(self, obj):
exc.add_note(f'when serializing {_T(obj)} element')
raise
write(ADDITEMS)
if n < self._BATCHSIZE:
return
dispatch[set] = save_set

def save_frozenset(self, obj):
Expand Down

0 comments on commit 11eb1ee

Please sign in to comment.