diff --git a/Lib/pickle.py b/Lib/pickle.py index ed8138beb908ee..3c1a6c1de20acb 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -548,10 +548,11 @@ def save(self, obj, save_persistent_id=True): self.framer.commit_frame() # Check for persistent id (defined by a subclass) - pid = self.persistent_id(obj) - if pid is not None and save_persistent_id: - self.save_pers(pid) - return + if save_persistent_id: + pid = self.persistent_id(obj) + if pid is not None: + self.save_pers(pid) + return # Check the memo x = self.memo.get(id(obj)) diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index c84e507cdf645f..6e55b18d527679 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -244,6 +244,84 @@ def persistent_load(subself, pid): unpickler = PersUnpickler(io.BytesIO(self.dumps('abc', proto))) self.assertEqual(unpickler.load(), 'abc') + def test_pickler_instance_attribute(self): + def persistent_id(obj): + called.append(obj) + return obj + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + f = io.BytesIO() + pickler = self.pickler(f, proto) + called = [] + old_persistent_id = pickler.persistent_id + pickler.persistent_id = persistent_id + self.assertEqual(pickler.persistent_id, persistent_id) + pickler.dump('abc') + self.assertEqual(called, ['abc']) + self.assertEqual(self.loads(f.getvalue()), 'abc') + del pickler.persistent_id + self.assertEqual(pickler.persistent_id, old_persistent_id) + + def test_unpickler_instance_attribute(self): + def persistent_load(pid): + called.append(pid) + return pid + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + unpickler = self.unpickler(io.BytesIO(self.dumps('abc', proto))) + called = [] + old_persistent_load = unpickler.persistent_load + unpickler.persistent_load = persistent_load + self.assertEqual(unpickler.persistent_load, persistent_load) + self.assertEqual(unpickler.load(), 'abc') + self.assertEqual(called, ['abc']) + del unpickler.persistent_load + self.assertEqual(unpickler.persistent_load, old_persistent_load) + + def test_pickler_super_instance_attribute(self): + class PersPickler(self.pickler): + def persistent_id(subself, obj): + raise AssertionError('should never be called') + def _persistent_id(subself, obj): + called.append(obj) + self.assertIsNone(super().persistent_id(obj)) + return obj + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + f = io.BytesIO() + pickler = PersPickler(f, proto) + called = [] + old_persistent_id = pickler.persistent_id + pickler.persistent_id = pickler._persistent_id + self.assertEqual(pickler.persistent_id, pickler._persistent_id) + pickler.dump('abc') + self.assertEqual(called, ['abc']) + self.assertEqual(self.loads(f.getvalue()), 'abc') + del pickler.persistent_id + self.assertEqual(pickler.persistent_id, old_persistent_id) + + def test_unpickler_super_instance_attribute(self): + class PersUnpickler(self.unpickler): + def persistent_load(subself, pid): + raise AssertionError('should never be called') + def _persistent_load(subself, pid): + called.append(pid) + with self.assertRaises(self.persistent_load_error): + super().persistent_load(pid) + return pid + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + unpickler = PersUnpickler(io.BytesIO(self.dumps('abc', proto))) + called = [] + old_persistent_load = unpickler.persistent_load + unpickler.persistent_load = unpickler._persistent_load + self.assertEqual(unpickler.persistent_load, unpickler._persistent_load) + self.assertEqual(unpickler.load(), 'abc') + self.assertEqual(called, ['abc']) + del unpickler.persistent_load + self.assertEqual(unpickler.persistent_load, old_persistent_load) + + class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase): pickler_class = pickle._Pickler @@ -367,7 +445,7 @@ class SizeofTests(unittest.TestCase): check_sizeof = support.check_sizeof def test_pickler(self): - basesize = support.calcobjsize('6P2n3i2n3i2P') + basesize = support.calcobjsize('7P2n3i2n3i2P') p = _pickle.Pickler(io.BytesIO()) self.assertEqual(object.__sizeof__(p), basesize) MT_size = struct.calcsize('3nP0n') @@ -384,7 +462,7 @@ def test_pickler(self): 0) # Write buffer is cleared after every dump(). def test_unpickler(self): - basesize = support.calcobjsize('2P2nP 2P2n2i5P 2P3n8P2n2i') + basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i') unpickler = _pickle.Unpickler P = struct.calcsize('P') # Size of memo table entry. n = struct.calcsize('n') # Size of mark table entry. diff --git a/Misc/NEWS.d/next/Library/2024-10-19-11-06-06.gh-issue-125631.BlhVvR.rst b/Misc/NEWS.d/next/Library/2024-10-19-11-06-06.gh-issue-125631.BlhVvR.rst new file mode 100644 index 00000000000000..e870abbf87803a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-10-19-11-06-06.gh-issue-125631.BlhVvR.rst @@ -0,0 +1,4 @@ +Restore ability to set :attr:`~pickle.Pickler.persistent_id` and +:attr:`~pickle.Unpickler.persistent_load` attributes of instances of the +:class:`!Pickler` and :class:`!Unpickler` classes in the :mod:`pickle` +module. diff --git a/Modules/_pickle.c b/Modules/_pickle.c index b2bd9545c1b130..7ea82c16e57d59 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -613,6 +613,7 @@ typedef struct PicklerObject { objects to support self-referential objects pickling. */ PyObject *persistent_id; /* persistent_id() method, can be NULL */ + PyObject *persistent_id_attr; /* instance attribute, can be NULL */ PyObject *dispatch_table; /* private dispatch_table, can be NULL */ PyObject *reducer_override; /* hook for invoking user-defined callbacks instead of save_global when pickling @@ -655,6 +656,7 @@ typedef struct UnpicklerObject { size_t memo_len; /* Number of objects in the memo */ PyObject *persistent_load; /* persistent_load() method, can be NULL. */ + PyObject *persistent_load_attr; /* instance attribute, can be NULL. */ Py_buffer buffer; char *input_buffer; @@ -1108,6 +1110,7 @@ _Pickler_New(PickleState *st) self->memo = memo; self->persistent_id = NULL; + self->persistent_id_attr = NULL; self->dispatch_table = NULL; self->reducer_override = NULL; self->write = NULL; @@ -1602,6 +1605,7 @@ _Unpickler_New(PyObject *module) self->memo_size = MEMO_SIZE; self->memo_len = 0; self->persistent_load = NULL; + self->persistent_load_attr = NULL; memset(&self->buffer, 0, sizeof(Py_buffer)); self->input_buffer = NULL; self->input_line = NULL; @@ -5088,6 +5092,33 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored)) return -1; } +static PyObject * +Pickler_getattr(PyObject *self, PyObject *name) +{ + if (PyUnicode_Check(name) + && PyUnicode_CompareWithASCIIString(name, "persistent_id") == 0 + && ((PicklerObject *)self)->persistent_id_attr) + { + return Py_NewRef(((PicklerObject *)self)->persistent_id_attr); + } + + return PyObject_GenericGetAttr(self, name); +} + +static int +Pickler_setattr(PyObject *self, PyObject *name, PyObject *value) +{ + if (PyUnicode_Check(name) + && PyUnicode_CompareWithASCIIString(name, "persistent_id") == 0) + { + Py_XINCREF(value); + Py_XSETREF(((PicklerObject *)self)->persistent_id_attr, value); + return 0; + } + + return PyObject_GenericSetAttr(self, name, value); +} + static PyMemberDef Pickler_members[] = { {"bin", Py_T_INT, offsetof(PicklerObject, bin)}, {"fast", Py_T_INT, offsetof(PicklerObject, fast)}, @@ -5103,6 +5134,8 @@ static PyGetSetDef Pickler_getsets[] = { static PyType_Slot pickler_type_slots[] = { {Py_tp_dealloc, Pickler_dealloc}, + {Py_tp_getattro, Pickler_getattr}, + {Py_tp_setattro, Pickler_setattr}, {Py_tp_methods, Pickler_methods}, {Py_tp_members, Pickler_members}, {Py_tp_getset, Pickler_getsets}, @@ -7562,6 +7595,33 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored return -1; } +static PyObject * +Unpickler_getattr(PyObject *self, PyObject *name) +{ + if (PyUnicode_Check(name) + && PyUnicode_CompareWithASCIIString(name, "persistent_load") == 0 + && ((UnpicklerObject *)self)->persistent_load_attr) + { + return Py_NewRef(((UnpicklerObject *)self)->persistent_load_attr); + } + + return PyObject_GenericGetAttr(self, name); +} + +static int +Unpickler_setattr(PyObject *self, PyObject *name, PyObject *value) +{ + if (PyUnicode_Check(name) + && PyUnicode_CompareWithASCIIString(name, "persistent_load") == 0) + { + Py_XINCREF(value); + Py_XSETREF(((UnpicklerObject *)self)->persistent_load_attr, value); + return 0; + } + + return PyObject_GenericSetAttr(self, name, value); +} + static PyGetSetDef Unpickler_getsets[] = { {"memo", (getter)Unpickler_get_memo, (setter)Unpickler_set_memo}, {NULL} @@ -7570,6 +7630,8 @@ static PyGetSetDef Unpickler_getsets[] = { static PyType_Slot unpickler_type_slots[] = { {Py_tp_dealloc, Unpickler_dealloc}, {Py_tp_doc, (char *)_pickle_Unpickler___init____doc__}, + {Py_tp_getattro, Unpickler_getattr}, + {Py_tp_setattro, Unpickler_setattr}, {Py_tp_traverse, Unpickler_traverse}, {Py_tp_clear, Unpickler_clear}, {Py_tp_methods, Unpickler_methods},