Skip to content

Commit

Permalink
pythongh-125631: Enable setting persistent_id and persistent_load of …
Browse files Browse the repository at this point in the history
…pickler and unpickler

pickle.Pickler.persistent_id and pickle.Unpickler.persistent_load can
again be overridden as instance attributes.
  • Loading branch information
serhiy-storchaka committed Oct 20, 2024
1 parent 2e950e3 commit 0f694d4
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 6 deletions.
9 changes: 5 additions & 4 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,11 @@ def save(self, obj, save_persistent_id=True):
self.framer.commit_frame()

# Check for persistent id (defined by a subclass)
pid = self.persistent_id(obj)
if pid is not None and save_persistent_id:
self.save_pers(pid)
return
if save_persistent_id:
pid = self.persistent_id(obj)
if pid is not None:
self.save_pers(pid)
return

# Check the memo
x = self.memo.get(id(obj))
Expand Down
82 changes: 80 additions & 2 deletions Lib/test/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,84 @@ def persistent_load(subself, pid):
unpickler = PersUnpickler(io.BytesIO(self.dumps('abc', proto)))
self.assertEqual(unpickler.load(), 'abc')

def test_pickler_instance_attribute(self):
def persistent_id(obj):
called.append(obj)
return obj

for proto in range(pickle.HIGHEST_PROTOCOL + 1):
f = io.BytesIO()
pickler = self.pickler(f, proto)
called = []
old_persistent_id = pickler.persistent_id
pickler.persistent_id = persistent_id
self.assertEqual(pickler.persistent_id, persistent_id)
pickler.dump('abc')
self.assertEqual(called, ['abc'])
self.assertEqual(self.loads(f.getvalue()), 'abc')
del pickler.persistent_id
self.assertEqual(pickler.persistent_id, old_persistent_id)

def test_unpickler_instance_attribute(self):
def persistent_load(pid):
called.append(pid)
return pid

for proto in range(pickle.HIGHEST_PROTOCOL + 1):
unpickler = self.unpickler(io.BytesIO(self.dumps('abc', proto)))
called = []
old_persistent_load = unpickler.persistent_load
unpickler.persistent_load = persistent_load
self.assertEqual(unpickler.persistent_load, persistent_load)
self.assertEqual(unpickler.load(), 'abc')
self.assertEqual(called, ['abc'])
del unpickler.persistent_load
self.assertEqual(unpickler.persistent_load, old_persistent_load)

def test_pickler_super_instance_attribute(self):
class PersPickler(self.pickler):
def persistent_id(subself, obj):
raise AssertionError('should never be called')
def _persistent_id(subself, obj):
called.append(obj)
self.assertIsNone(super().persistent_id(obj))
return obj

for proto in range(pickle.HIGHEST_PROTOCOL + 1):
f = io.BytesIO()
pickler = PersPickler(f, proto)
called = []
old_persistent_id = pickler.persistent_id
pickler.persistent_id = pickler._persistent_id
self.assertEqual(pickler.persistent_id, pickler._persistent_id)
pickler.dump('abc')
self.assertEqual(called, ['abc'])
self.assertEqual(self.loads(f.getvalue()), 'abc')
del pickler.persistent_id
self.assertEqual(pickler.persistent_id, old_persistent_id)

def test_unpickler_super_instance_attribute(self):
class PersUnpickler(self.unpickler):
def persistent_load(subself, pid):
raise AssertionError('should never be called')
def _persistent_load(subself, pid):
called.append(pid)
with self.assertRaises(self.persistent_load_error):
super().persistent_load(pid)
return pid

for proto in range(pickle.HIGHEST_PROTOCOL + 1):
unpickler = PersUnpickler(io.BytesIO(self.dumps('abc', proto)))
called = []
old_persistent_load = unpickler.persistent_load
unpickler.persistent_load = unpickler._persistent_load
self.assertEqual(unpickler.persistent_load, unpickler._persistent_load)
self.assertEqual(unpickler.load(), 'abc')
self.assertEqual(called, ['abc'])
del unpickler.persistent_load
self.assertEqual(unpickler.persistent_load, old_persistent_load)


class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):

pickler_class = pickle._Pickler
Expand Down Expand Up @@ -367,7 +445,7 @@ class SizeofTests(unittest.TestCase):
check_sizeof = support.check_sizeof

def test_pickler(self):
basesize = support.calcobjsize('6P2n3i2n3i2P')
basesize = support.calcobjsize('7P2n3i2n3i2P')
p = _pickle.Pickler(io.BytesIO())
self.assertEqual(object.__sizeof__(p), basesize)
MT_size = struct.calcsize('3nP0n')
Expand All @@ -384,7 +462,7 @@ def test_pickler(self):
0) # Write buffer is cleared after every dump().

def test_unpickler(self):
basesize = support.calcobjsize('2P2nP 2P2n2i5P 2P3n8P2n2i')
basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
unpickler = _pickle.Unpickler
P = struct.calcsize('P') # Size of memo table entry.
n = struct.calcsize('n') # Size of mark table entry.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Restore ability to set :attr:`~pickle.Pickler.persistent_id` and
:attr:`~pickle.Unpickler.persistent_load` attributes of instances of the
:class:`!Pickler` and :class:`!Unpickler` classes in the :mod:`pickle`
module.
62 changes: 62 additions & 0 deletions Modules/_pickle.c
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ typedef struct PicklerObject {
objects to support self-referential objects
pickling. */
PyObject *persistent_id; /* persistent_id() method, can be NULL */
PyObject *persistent_id_attr; /* instance attribute, can be NULL */
PyObject *dispatch_table; /* private dispatch_table, can be NULL */
PyObject *reducer_override; /* hook for invoking user-defined callbacks
instead of save_global when pickling
Expand Down Expand Up @@ -655,6 +656,7 @@ typedef struct UnpicklerObject {
size_t memo_len; /* Number of objects in the memo */

PyObject *persistent_load; /* persistent_load() method, can be NULL. */
PyObject *persistent_load_attr; /* instance attribute, can be NULL. */

Py_buffer buffer;
char *input_buffer;
Expand Down Expand Up @@ -1108,6 +1110,7 @@ _Pickler_New(PickleState *st)

self->memo = memo;
self->persistent_id = NULL;
self->persistent_id_attr = NULL;
self->dispatch_table = NULL;
self->reducer_override = NULL;
self->write = NULL;
Expand Down Expand Up @@ -1602,6 +1605,7 @@ _Unpickler_New(PyObject *module)
self->memo_size = MEMO_SIZE;
self->memo_len = 0;
self->persistent_load = NULL;
self->persistent_load_attr = NULL;
memset(&self->buffer, 0, sizeof(Py_buffer));
self->input_buffer = NULL;
self->input_line = NULL;
Expand Down Expand Up @@ -5088,6 +5092,33 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored))
return -1;
}

static PyObject *
Pickler_getattr(PyObject *self, PyObject *name)
{
if (PyUnicode_Check(name)
&& PyUnicode_CompareWithASCIIString(name, "persistent_id") == 0
&& ((PicklerObject *)self)->persistent_id_attr)
{
return Py_NewRef(((PicklerObject *)self)->persistent_id_attr);
}

return PyObject_GenericGetAttr(self, name);
}

static int
Pickler_setattr(PyObject *self, PyObject *name, PyObject *value)
{
if (PyUnicode_Check(name)
&& PyUnicode_CompareWithASCIIString(name, "persistent_id") == 0)
{
Py_XINCREF(value);
Py_XSETREF(((PicklerObject *)self)->persistent_id_attr, value);
return 0;
}

return PyObject_GenericSetAttr(self, name, value);
}

static PyMemberDef Pickler_members[] = {
{"bin", Py_T_INT, offsetof(PicklerObject, bin)},
{"fast", Py_T_INT, offsetof(PicklerObject, fast)},
Expand All @@ -5103,6 +5134,8 @@ static PyGetSetDef Pickler_getsets[] = {

static PyType_Slot pickler_type_slots[] = {
{Py_tp_dealloc, Pickler_dealloc},
{Py_tp_getattro, Pickler_getattr},
{Py_tp_setattro, Pickler_setattr},
{Py_tp_methods, Pickler_methods},
{Py_tp_members, Pickler_members},
{Py_tp_getset, Pickler_getsets},
Expand Down Expand Up @@ -7562,6 +7595,33 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored
return -1;
}

static PyObject *
Unpickler_getattr(PyObject *self, PyObject *name)
{
if (PyUnicode_Check(name)
&& PyUnicode_CompareWithASCIIString(name, "persistent_load") == 0
&& ((UnpicklerObject *)self)->persistent_load_attr)
{
return Py_NewRef(((UnpicklerObject *)self)->persistent_load_attr);
}

return PyObject_GenericGetAttr(self, name);
}

static int
Unpickler_setattr(PyObject *self, PyObject *name, PyObject *value)
{
if (PyUnicode_Check(name)
&& PyUnicode_CompareWithASCIIString(name, "persistent_load") == 0)
{
Py_XINCREF(value);
Py_XSETREF(((UnpicklerObject *)self)->persistent_load_attr, value);
return 0;
}

return PyObject_GenericSetAttr(self, name, value);
}

static PyGetSetDef Unpickler_getsets[] = {
{"memo", (getter)Unpickler_get_memo, (setter)Unpickler_set_memo},
{NULL}
Expand All @@ -7570,6 +7630,8 @@ static PyGetSetDef Unpickler_getsets[] = {
static PyType_Slot unpickler_type_slots[] = {
{Py_tp_dealloc, Unpickler_dealloc},
{Py_tp_doc, (char *)_pickle_Unpickler___init____doc__},
{Py_tp_getattro, Unpickler_getattr},
{Py_tp_setattro, Unpickler_setattr},
{Py_tp_traverse, Unpickler_traverse},
{Py_tp_clear, Unpickler_clear},
{Py_tp_methods, Unpickler_methods},
Expand Down

0 comments on commit 0f694d4

Please sign in to comment.