Skip to content

Commit

Permalink
pythongh-126317: Simplify stdlib code by using itertools.batched()
Browse files Browse the repository at this point in the history
  • Loading branch information
dongwooklee96 committed Nov 2, 2024
1 parent f0c6fcc commit 4fa459b
Showing 1 changed file with 21 additions and 39 deletions.
60 changes: 21 additions & 39 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from types import FunctionType
from copyreg import dispatch_table
from copyreg import _extension_registry, _inverted_registry, _extension_cache
from itertools import islice
from itertools import batched
from functools import partial
import sys
from sys import maxsize
Expand Down Expand Up @@ -1033,31 +1033,25 @@ def _batch_appends(self, items, obj):
write(APPEND)
return

it = iter(items)
start = 0
while True:
tmp = list(islice(it, self._BATCHSIZE))
n = len(tmp)
if n > 1:
for batch in batched(items, self._BATCHSIZE):
if len(batch) != 1:
write(MARK)
for i, x in enumerate(tmp, start):
for i, x in enumerate(batch, start):
try:
save(x)
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} item {i}')
raise
write(APPENDS)
elif n:
else:
try:
save(tmp[0])
save(batch[0])
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} item {start}')
raise
write(APPEND)
# else tmp is empty, and we're done
if n < self._BATCHSIZE:
return
start += n
start += len(batch)

def save_dict(self, obj):
if self.bin:
Expand Down Expand Up @@ -1086,32 +1080,26 @@ def _batch_setitems(self, items, obj):
write(SETITEM)
return

it = iter(items)
while True:
tmp = list(islice(it, self._BATCHSIZE))
n = len(tmp)
if n > 1:
for batch in batched(items, self._BATCHSIZE):
if len(batch) != 1:
write(MARK)
for k, v in tmp:
for k, v in batch:
save(k)
try:
save(v)
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} item {k!r}')
raise
write(SETITEMS)
elif n:
k, v = tmp[0]
else:
k, v = batch[0]
save(k)
try:
save(v)
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} item {k!r}')
raise
write(SETITEM)
# else tmp is empty, and we're done
if n < self._BATCHSIZE:
return

def save_set(self, obj):
save = self.save
Expand All @@ -1124,21 +1112,15 @@ def save_set(self, obj):
write(EMPTY_SET)
self.memoize(obj)

it = iter(obj)
while True:
batch = list(islice(it, self._BATCHSIZE))
n = len(batch)
if n > 0:
write(MARK)
try:
for item in batch:
save(item)
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} element')
raise
write(ADDITEMS)
if n < self._BATCHSIZE:
return
for batch in batched(obj, self._BATCHSIZE):
write(MARK)
try:
for item in batch:
save(item)
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} element')
raise
write(ADDITEMS)
dispatch[set] = save_set

def save_frozenset(self, obj):
Expand Down

0 comments on commit 4fa459b

Please sign in to comment.