diff --git a/AUTHORS.rst b/AUTHORS.rst index 4954da7b6..0d32aef66 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -80,6 +80,7 @@ Authors - Klaas van Schelven - Kris Neuharth - Kyle Seever (`kseever `_) +- Léni Gauffier (`legau `_) - Leticia Portella - Lucas Wiman - Maciej "RooTer" Urbański diff --git a/docs/historical_model.rst b/docs/historical_model.rst index 2b0bc78ce..741781e7d 100644 --- a/docs/historical_model.rst +++ b/docs/historical_model.rst @@ -461,11 +461,15 @@ If you want to track many to many relationships, you need to define them explici class Poll(models.Model): question = models.CharField(max_length=200) categories = models.ManyToManyField(Category) - history = HistoricalRecords(many_to_many=[categories]) + history = HistoricalRecords(m2m_fields=[categories]) This will create a historical intermediate model that tracks each relational change between `Poll` and `Category`. +You may also define these fields in a model attribute (by default on `_history_m2m_fields`). +This is mainly used for inherited models. You can override the attribute name by setting +your own `m2m_fields_model_field_name` argument on the `HistoricalRecord` instance. + You will see the many to many changes when diffing between two historical records: .. code-block:: python diff --git a/docs/signals.rst b/docs/signals.rst index ee3cc6608..5f168dea9 100644 --- a/docs/signals.rst +++ b/docs/signals.rst @@ -22,6 +22,24 @@ saving a historical record. Arguments passed to the signals include the followin using The database alias being used +For Many To Many signals you've got the following : + +.. glossary:: + instance + The source model instance being saved + + history_instance + The corresponding history record + + rows (for pre_create) + The elements to be bulk inserted into the m2m table + + created_rows (for post_create) + The created elements into the m2m table + + field + The recorded field object + To connect the signals to your callbacks, you can use the ``@receiver`` decorator: .. code-block:: python @@ -30,6 +48,8 @@ To connect the signals to your callbacks, you can use the ``@receiver`` decorato from simple_history.signals import ( pre_create_historical_record, post_create_historical_record + pre_create_historical_m2m_records, + post_create_historical_m2m_records, ) @receiver(pre_create_historical_record) @@ -39,3 +59,11 @@ To connect the signals to your callbacks, you can use the ``@receiver`` decorato @receiver(post_create_historical_record) def post_create_historical_record_callback(sender, **kwargs): print("Sent after saving historical record") + + @receiver(pre_create_historical_m2m_records) + def pre_create_historical_m2m_records_callback(sender, **kwargs): + print("Sent before saving many to many field on historical record") + + @receiver(post_create_historical_m2m_records) + def post_create_historical_m2m_records_callback(sender, **kwargs): + print("Sent after saving many to many field on historical record") diff --git a/simple_history/models.py b/simple_history/models.py index 38e9b0814..0a4ae6dcf 100644 --- a/simple_history/models.py +++ b/simple_history/models.py @@ -32,7 +32,12 @@ from . import exceptions from .manager import SIMPLE_HISTORY_REVERSE_ATTR_NAME, HistoryDescriptor -from .signals import post_create_historical_record, pre_create_historical_record +from .signals import ( + post_create_historical_m2m_records, + post_create_historical_record, + pre_create_historical_m2m_records, + pre_create_historical_record, +) from .utils import get_change_reason_from_object try: @@ -94,6 +99,8 @@ def __init__( no_db_index=list(), excluded_field_kwargs=None, m2m_fields=(), + m2m_fields_model_field_name="_history_m2m_fields", + m2m_bases=(models.Model,), ): self.user_set_verbose_name = verbose_name self.user_set_verbose_name_plural = verbose_name_plural @@ -114,6 +121,7 @@ def __init__( self.related_name = related_name self.use_base_model_db = use_base_model_db self.m2m_fields = m2m_fields + self.m2m_fields_model_field_name = m2m_fields_model_field_name if isinstance(no_db_index, str): no_db_index = [no_db_index] @@ -132,6 +140,12 @@ def __init__( self.bases = (HistoricalChanges,) + tuple(bases) except TypeError: raise TypeError("The `bases` option must be a list or a tuple.") + try: + if isinstance(m2m_bases, str): + raise TypeError + self.m2m_bases = (HistoricalChanges,) + tuple(m2m_bases) + except TypeError: + raise TypeError("The `m2m_bases` option must be a list or a tuple.") def contribute_to_class(self, cls, name): self.manager_name = name @@ -189,7 +203,10 @@ def finalize(self, sender, **kwargs): # so the signal handlers can't use weak references. models.signals.post_save.connect(self.post_save, sender=sender, weak=False) models.signals.post_delete.connect(self.post_delete, sender=sender, weak=False) - for field in self.m2m_fields: + + m2m_fields = self.get_m2m_fields_from_model(sender) + + for field in m2m_fields: m2m_changed.connect( partial(self.m2m_changed, attr=field.name), sender=field.remote_field.through, @@ -200,13 +217,12 @@ def finalize(self, sender, **kwargs): setattr(sender, self.manager_name, descriptor) sender._meta.simple_history_manager_attribute = self.manager_name - for field in self.m2m_fields: + for field in m2m_fields: m2m_model = self.create_history_m2m_model( history_model, field.remote_field.through ) self.m2m_models[field] = m2m_model - module = importlib.import_module(self.module) setattr(module, m2m_model.__name__, m2m_model) m2m_descriptor = HistoryDescriptor(m2m_model) @@ -235,46 +251,18 @@ def get_history_model_name(self, model): ) def create_history_m2m_model(self, model, through_model): - attrs = { - "__module__": self.module, - "__str__": lambda self: "{} as of {}".format( - self._meta.verbose_name, self.history.history_date - ), - } - - app_module = "%s.models" % model._meta.app_label - - if model.__module__ != self.module: - # registered under different app - attrs["__module__"] = self.module - elif app_module != self.module: - # Abuse an internal API because the app registry is loading. - app = apps.app_configs[model._meta.app_label] - models_module = app.name - attrs["__module__"] = models_module - - # Get the primary key to the history model this model will look up to - attrs["m2m_history_id"] = self._get_history_id_field() - attrs["history"] = models.ForeignKey( - model, - db_constraint=False, - on_delete=models.DO_NOTHING, - ) - attrs["instance_type"] = through_model + attrs = {} fields = self.copy_fields(through_model) attrs.update(fields) + attrs.update(self.get_extra_fields_m2m(model, through_model, fields)) name = self.get_history_model_name(through_model) registered_models[through_model._meta.db_table] = through_model - meta_fields = {"verbose_name": name} - if self.app: - meta_fields["app_label"] = self.app - - attrs.update(Meta=type(str("Meta"), (), meta_fields)) + attrs.update(Meta=type("Meta", (), self.get_meta_options_m2m(through_model))) - m2m_history_model = type(str(name), (models.Model,), attrs) + m2m_history_model = type(str(name), self.m2m_bases, attrs) return m2m_history_model @@ -285,7 +273,7 @@ def create_history_model(self, model, inherited): attrs = { "__module__": self.module, "_history_excluded_fields": self.excluded_fields, - "_history_m2m_fields": self.m2m_fields, + "_history_m2m_fields": self.get_m2m_fields_from_model(model), } app_module = "%s.models" % model._meta.app_label @@ -412,7 +400,7 @@ def _get_history_change_reason_field(self): def _get_history_id_field(self): if self.history_id_field: - history_id_field = self.history_id_field + history_id_field = self.history_id_field.clone() history_id_field.primary_key = True history_id_field.editable = False elif getattr(settings, "SIMPLE_HISTORY_HISTORY_ID_USE_UUID", False): @@ -465,6 +453,25 @@ def _get_history_related_field(self, model): else: return {} + def get_extra_fields_m2m(self, model, through_model, fields): + """Return dict of extra fields added to the m2m historical record model""" + + extra_fields = { + "__module__": model.__module__, + "__str__": lambda self: "{} as of {}".format( + self._meta.verbose_name, self.history.history_date + ), + "history": models.ForeignKey( + model, + db_constraint=False, + on_delete=models.DO_NOTHING, + ), + "instance_type": through_model, + "m2m_history_id": self._get_history_id_field(), + } + + return extra_fields + def get_extra_fields(self, model, fields): """Return dict of extra fields added to the historical record model""" @@ -577,6 +584,20 @@ def _date_indexing(self): ) return result + def get_meta_options_m2m(self, through_model): + """ + Returns a dictionary of fields that will be added to + the Meta inner class of the m2m historical record model. + """ + name = self.get_history_model_name(through_model) + + meta_fields = {"verbose_name": name} + + if self.app: + meta_fields["app_label"] = self.app + + return meta_fields + def get_meta_options(self, model): """ Returns a dictionary of fields that will be added to @@ -637,7 +658,7 @@ def m2m_changed(self, instance, action, attr, pk_set, reverse, **_): self.create_historical_record(instance, "~") def create_historical_record_m2ms(self, history_instance, instance): - for field in self.m2m_fields: + for field in history_instance._history_m2m_fields: m2m_history_model = self.m2m_models[field] original_instance = history_instance.instance through_model = getattr(original_instance, field.name).through @@ -657,7 +678,21 @@ def create_historical_record_m2ms(self, history_instance, instance): ) insert_rows.append(m2m_history_model(**insert_row)) - m2m_history_model.objects.bulk_create(insert_rows) + pre_create_historical_m2m_records.send( + sender=m2m_history_model, + rows=insert_rows, + history_instance=history_instance, + instance=instance, + field=field, + ) + created_rows = m2m_history_model.objects.bulk_create(insert_rows) + post_create_historical_m2m_records.send( + sender=m2m_history_model, + created_rows=created_rows, + history_instance=history_instance, + instance=instance, + field=field, + ) def create_historical_record(self, instance, history_type, using=None): using = using if self.use_base_model_db else None @@ -721,6 +756,14 @@ def get_history_user(self, instance): return self.get_user(instance=instance, request=request) + def get_m2m_fields_from_model(self, model): + m2m_fields = set(self.m2m_fields) + try: + m2m_fields.update(getattr(model, self.m2m_fields_model_field_name)) + except AttributeError: + pass + return [getattr(model, field.name).field for field in m2m_fields] + def transform_field(field): """Customize field appropriately for use in historical model""" @@ -880,12 +923,20 @@ def diff_against(self, old_history, excluded_fields=None, included_fields=None): if excluded_fields is None: excluded_fields = set() + included_m2m_fields = {field.name for field in old_history._history_m2m_fields} if included_fields is None: included_fields = { f.name for f in old_history.instance_type._meta.fields if f.editable } + else: + included_m2m_fields = included_m2m_fields.intersection(included_fields) - fields = set(included_fields).difference(excluded_fields) + fields = ( + set(included_fields) + .difference(included_m2m_fields) + .difference(excluded_fields) + ) + m2m_fields = set(included_m2m_fields).difference(excluded_fields) changes = [] changed_fields = [] @@ -902,11 +953,10 @@ def diff_against(self, old_history, excluded_fields=None, included_fields=None): changed_fields.append(field) # Separately compare m2m fields: - for field in old_history._history_m2m_fields: + for field in m2m_fields: # First retrieve a single item to get the field names from: reference_history_m2m_item = ( - getattr(old_history, field.name).first() - or getattr(self, field.name).first() + getattr(old_history, field).first() or getattr(self, field).first() ) history_field_names = [] if reference_history_m2m_item: @@ -920,15 +970,13 @@ def diff_against(self, old_history, excluded_fields=None, included_fields=None): if f.editable and f.name not in ["id", "m2m_history_id", "history"] ] - old_rows = list( - getattr(old_history, field.name).values(*history_field_names) - ) - new_rows = list(getattr(self, field.name).values(*history_field_names)) + old_rows = list(getattr(old_history, field).values(*history_field_names)) + new_rows = list(getattr(self, field).values(*history_field_names)) if old_rows != new_rows: - change = ModelChange(field.name, old_rows, new_rows) + change = ModelChange(field, old_rows, new_rows) changes.append(change) - changed_fields.append(field.name) + changed_fields.append(field) return ModelDelta(changes, changed_fields, old_history, self) diff --git a/simple_history/signals.py b/simple_history/signals.py index 090008089..270dc3843 100644 --- a/simple_history/signals.py +++ b/simple_history/signals.py @@ -7,3 +7,11 @@ # Arguments: "instance", "history_instance", "history_date", # "history_user", "history_change_reason", "using" post_create_historical_record = django.dispatch.Signal() + +# Arguments: "sender", "rows", "history_instance", "instance", +# "field" +pre_create_historical_m2m_records = django.dispatch.Signal() + +# Arguments: "sender", "created_rows", "history_instance", +# "instance", "field" +post_create_historical_m2m_records = django.dispatch.Signal() diff --git a/simple_history/tests/models.py b/simple_history/tests/models.py index 0dab6c7a9..681c20731 100644 --- a/simple_history/tests/models.py +++ b/simple_history/tests/models.py @@ -117,6 +117,37 @@ class PollWithManyToMany(models.Model): history = HistoricalRecords(m2m_fields=[places]) +class PollWithManyToManyCustomHistoryID(models.Model): + question = models.CharField(max_length=200) + pub_date = models.DateTimeField("date published") + places = models.ManyToManyField("Place") + + history = HistoricalRecords( + m2m_fields=[places], history_id_field=models.UUIDField(default=uuid.uuid4) + ) + + +class HistoricalRecordsWithExtraFieldM2M(HistoricalRecords): + def get_extra_fields_m2m(self, model, through_model, fields): + extra_fields = super().get_extra_fields_m2m(model, through_model, fields) + + def get_class_name(self): + return self.__class__.__name__ + + extra_fields["get_class_name"] = get_class_name + return extra_fields + + +class PollWithManyToManyWithIPAddress(models.Model): + question = models.CharField(max_length=200) + pub_date = models.DateTimeField("date published") + places = models.ManyToManyField("Place") + + history = HistoricalRecordsWithExtraFieldM2M( + m2m_fields=[places], m2m_bases=[IPAddressHistoricalModel] + ) + + class PollWithSeveralManyToMany(models.Model): question = models.CharField(max_length=200) pub_date = models.DateTimeField("date published") @@ -127,6 +158,32 @@ class PollWithSeveralManyToMany(models.Model): history = HistoricalRecords(m2m_fields=[places, restaurants, books]) +class PollParentWithManyToMany(models.Model): + question = models.CharField(max_length=200) + pub_date = models.DateTimeField("date published") + places = models.ManyToManyField("Place") + + history = HistoricalRecords( + m2m_fields=[places], + inherit=True, + ) + + class Meta: + abstract = True + + +class PollChildBookWithManyToMany(PollParentWithManyToMany): + books = models.ManyToManyField("Book", related_name="books_poll_child") + _history_m2m_fields = [books] + + +class PollChildRestaurantWithManyToMany(PollParentWithManyToMany): + restaurants = models.ManyToManyField( + "Restaurant", related_name="restaurants_poll_child" + ) + _history_m2m_fields = [restaurants] + + class CustomAttrNameForeignKey(models.ForeignKey): def __init__(self, *args, **kwargs): self.attr_name = kwargs.pop("attr_name", None) diff --git a/simple_history/tests/tests/test_models.py b/simple_history/tests/tests/test_models.py index c53507d32..206410f8f 100644 --- a/simple_history/tests/tests/test_models.py +++ b/simple_history/tests/tests/test_models.py @@ -24,7 +24,10 @@ is_historic, to_historic, ) -from simple_history.signals import pre_create_historical_record +from simple_history.signals import ( + pre_create_historical_m2m_records, + pre_create_historical_record, +) from simple_history.tests.custom_user.models import CustomUser from simple_history.tests.tests.utils import ( database_router_override_settings, @@ -87,12 +90,16 @@ Person, Place, Poll, + PollChildBookWithManyToMany, + PollChildRestaurantWithManyToMany, PollInfo, PollWithExcludedFieldsWithDefaults, PollWithExcludedFKField, PollWithExcludeFields, PollWithHistoricalIPAddress, PollWithManyToMany, + PollWithManyToManyCustomHistoryID, + PollWithManyToManyWithIPAddress, PollWithNonEditableField, PollWithSeveralManyToMany, Province, @@ -1485,6 +1492,11 @@ def add_static_history_ip_address(sender, **kwargs): history_instance.ip_address = "192.168.0.1" +def add_static_history_ip_address_on_m2m(sender, rows, **kwargs): + for row in rows: + row.ip_address = "192.168.0.1" + + class ExtraFieldsStaticIPAddressTestCase(TestCase): def setUp(self): pre_create_historical_record.connect( @@ -1736,6 +1748,115 @@ def test_separation(self): self.assertEqual(add.places.all().count(), 0) +class InheritedManyToManyTest(TestCase): + def setUp(self): + self.model_book = PollChildBookWithManyToMany + self.model_rstr = PollChildRestaurantWithManyToMany + self.place = Place.objects.create(name="Home") + self.book = Book.objects.create(isbn="1234") + self.restaurant = Restaurant.objects.create(rating=1) + self.poll_book = self.model_book.objects.create( + question="what's up?", pub_date=today + ) + self.poll_rstr = self.model_rstr.objects.create( + question="what's up?", pub_date=today + ) + + def test_separation(self): + self.assertEqual(self.poll_book.history.all().count(), 1) + self.poll_book.places.add(self.place) + self.poll_book.books.add(self.book) + self.assertEqual(self.poll_book.history.all().count(), 3) + + self.assertEqual(self.poll_rstr.history.all().count(), 1) + self.poll_rstr.places.add(self.place) + self.poll_rstr.restaurants.add(self.restaurant) + self.assertEqual(self.poll_rstr.history.all().count(), 3) + + book, place, add = self.poll_book.history.all() + + self.assertEqual(book.books.all().count(), 1) + self.assertEqual(book.places.all().count(), 1) + self.assertEqual(book.books.first().book, self.book) + + self.assertEqual(place.books.all().count(), 0) + self.assertEqual(place.places.all().count(), 1) + self.assertEqual(place.places.first().place, self.place) + + self.assertEqual(add.books.all().count(), 0) + self.assertEqual(add.places.all().count(), 0) + + restaurant, place, add = self.poll_rstr.history.all() + + self.assertEqual(restaurant.restaurants.all().count(), 1) + self.assertEqual(restaurant.places.all().count(), 1) + self.assertEqual(restaurant.restaurants.first().restaurant, self.restaurant) + + self.assertEqual(place.restaurants.all().count(), 0) + self.assertEqual(place.places.all().count(), 1) + self.assertEqual(place.places.first().place, self.place) + + self.assertEqual(add.restaurants.all().count(), 0) + self.assertEqual(add.places.all().count(), 0) + + +class ManyToManyWithSignalsTest(TestCase): + def setUp(self): + self.model = PollWithManyToManyWithIPAddress + # self.historical_through_model = self.model.history. + self.places = ( + Place.objects.create(name="London"), + Place.objects.create(name="Paris"), + ) + self.poll = self.model.objects.create(question="what's up?", pub_date=today) + pre_create_historical_m2m_records.connect( + add_static_history_ip_address_on_m2m, + dispatch_uid="add_static_history_ip_address_on_m2m", + ) + + def tearDown(self): + pre_create_historical_m2m_records.disconnect( + add_static_history_ip_address_on_m2m, + dispatch_uid="add_static_history_ip_address_on_m2m", + ) + + def test_ip_address_added(self): + self.poll.places.add(*self.places) + + places = self.poll.history.first().places + self.assertEqual(2, places.count()) + for place in places.all(): + self.assertEqual("192.168.0.1", place.ip_address) + + def test_extra_field(self): + self.poll.places.add(*self.places) + m2m_record = self.poll.history.first().places.first() + self.assertEqual( + m2m_record.get_class_name(), + "HistoricalPollWithManyToManyWithIPAddress_places", + ) + + def test_diff(self): + self.poll.places.clear() + self.poll.places.add(*self.places) + + new = self.poll.history.first() + old = new.prev_record + + delta = new.diff_against(old) + + self.assertEqual("places", delta.changes[0].field) + self.assertEqual(2, len(delta.changes[0].new)) + + +class ManyToManyCustomIDTest(TestCase): + def setUp(self): + self.model = PollWithManyToManyCustomHistoryID + self.history_model = self.model.history.model + self.place = Place.objects.create(name="Home") + self.poll = self.model.objects.create(question="what's up?", pub_date=today) + + class ManyToManyTest(TestCase): def setUp(self): self.model = PollWithManyToMany @@ -2014,6 +2135,17 @@ def test_diff_against(self): self.assertListEqual(expected_change.new, delta.changes[0].new) self.assertListEqual(expected_change.old, delta.changes[0].old) + delta = add_record.diff_against(create_record, included_fields=["places"]) + self.assertEqual(delta.changed_fields, ["places"]) + self.assertEqual(delta.old_record, create_record) + self.assertEqual(delta.new_record, add_record) + self.assertEqual(expected_change.field, delta.changes[0].field) + + delta = add_record.diff_against(create_record, excluded_fields=["places"]) + self.assertEqual(delta.changed_fields, []) + self.assertEqual(delta.old_record, create_record) + self.assertEqual(delta.new_record, add_record) + self.poll.places.clear() # First and third records are effectively the same. diff --git a/simple_history/tests/tests/test_signals.py b/simple_history/tests/tests/test_signals.py index fc7a0f050..fe0b9c909 100644 --- a/simple_history/tests/tests/test_signals.py +++ b/simple_history/tests/tests/test_signals.py @@ -3,11 +3,13 @@ from django.test import TestCase from simple_history.signals import ( + post_create_historical_m2m_records, post_create_historical_record, + pre_create_historical_m2m_records, pre_create_historical_record, ) -from ..models import Poll +from ..models import Place, Poll, PollWithManyToMany today = datetime(2021, 1, 1, 10, 0) @@ -18,6 +20,8 @@ def setUp(self): self.signal_instance = None self.signal_history_instance = None self.signal_sender = None + self.field = None + self.rows = None def test_pre_create_historical_record_signal(self): def handler(sender, instance, **kwargs): @@ -52,3 +56,59 @@ def handler(sender, instance, history_instance, **kwargs): self.assertEqual(self.signal_instance, p) self.assertIsNotNone(self.signal_history_instance) self.assertEqual(self.signal_sender, p.history.first().__class__) + + def test_pre_create_historical_m2m_records_signal(self): + def handler(sender, rows, history_instance, instance, field, **kwargs): + self.signal_was_called = True + self.signal_instance = instance + self.signal_history_instance = history_instance + self.signal_sender = sender + self.rows = rows + self.field = field + + pre_create_historical_m2m_records.connect(handler) + + p = PollWithManyToMany( + question="what's up?", + pub_date=today, + ) + p.save() + self.setUp() + p.places.add( + Place.objects.create(name="London"), Place.objects.create(name="Paris") + ) + + self.assertTrue(self.signal_was_called) + self.assertEqual(self.signal_instance, p) + self.assertIsNotNone(self.signal_history_instance) + self.assertEqual(self.signal_sender, p.history.first().places.model) + self.assertEqual(self.field, PollWithManyToMany._meta.many_to_many[0]) + self.assertEqual(len(self.rows), 2) + + def test_post_create_historical_m2m_records_signal(self): + def handler(sender, created_rows, history_instance, instance, field, **kwargs): + self.signal_was_called = True + self.signal_instance = instance + self.signal_history_instance = history_instance + self.signal_sender = sender + self.rows = created_rows + self.field = field + + post_create_historical_m2m_records.connect(handler) + + p = PollWithManyToMany( + question="what's up?", + pub_date=today, + ) + p.save() + self.setUp() + p.places.add( + Place.objects.create(name="London"), Place.objects.create(name="Paris") + ) + + self.assertTrue(self.signal_was_called) + self.assertEqual(self.signal_instance, p) + self.assertIsNotNone(self.signal_history_instance) + self.assertEqual(self.signal_sender, p.history.first().places.model) + self.assertEqual(self.field, PollWithManyToMany._meta.many_to_many[0]) + self.assertEqual(len(self.rows), 2)