Skip to content

Commit

Permalink
Merge pull request #1042 from legau/m2m-support
Browse files Browse the repository at this point in the history
M2M support for inheritance and signals
  • Loading branch information
valberg authored Oct 24, 2022
2 parents 7042a64 + a8c08b7 commit 52dafb6
Show file tree
Hide file tree
Showing 8 changed files with 391 additions and 53 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ Authors
- Klaas van Schelven
- Kris Neuharth
- Kyle Seever (`kseever <https://github.com/kseever>`_)
- Léni Gauffier (`legau <https://github.com/legau>`_)
- Leticia Portella
- Lucas Wiman
- Maciej "RooTer" Urbański
Expand Down
6 changes: 5 additions & 1 deletion docs/historical_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,15 @@ If you want to track many to many relationships, you need to define them explici
class Poll(models.Model):
question = models.CharField(max_length=200)
categories = models.ManyToManyField(Category)
history = HistoricalRecords(many_to_many=[categories])
history = HistoricalRecords(m2m_fields=[categories])
This will create a historical intermediate model that tracks each relational change
between `Poll` and `Category`.

You may also define these fields in a model attribute (by default on `_history_m2m_fields`).
This is mainly used for inherited models. You can override the attribute name by setting
your own `m2m_fields_model_field_name` argument on the `HistoricalRecord` instance.

You will see the many to many changes when diffing between two historical records:

.. code-block:: python
Expand Down
28 changes: 28 additions & 0 deletions docs/signals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ saving a historical record. Arguments passed to the signals include the followin
using
The database alias being used

For Many To Many signals you've got the following :

.. glossary::
instance
The source model instance being saved

history_instance
The corresponding history record

rows (for pre_create)
The elements to be bulk inserted into the m2m table

created_rows (for post_create)
The created elements into the m2m table

field
The recorded field object

To connect the signals to your callbacks, you can use the ``@receiver`` decorator:

.. code-block:: python
Expand All @@ -30,6 +48,8 @@ To connect the signals to your callbacks, you can use the ``@receiver`` decorato
from simple_history.signals import (
pre_create_historical_record,
post_create_historical_record
pre_create_historical_m2m_records,
post_create_historical_m2m_records,
)
@receiver(pre_create_historical_record)
Expand All @@ -39,3 +59,11 @@ To connect the signals to your callbacks, you can use the ``@receiver`` decorato
@receiver(post_create_historical_record)
def post_create_historical_record_callback(sender, **kwargs):
print("Sent after saving historical record")
@receiver(pre_create_historical_m2m_records)
def pre_create_historical_m2m_records_callback(sender, **kwargs):
print("Sent before saving many to many field on historical record")
@receiver(post_create_historical_m2m_records)
def post_create_historical_m2m_records_callback(sender, **kwargs):
print("Sent after saving many to many field on historical record")
148 changes: 98 additions & 50 deletions simple_history/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@

from . import exceptions
from .manager import SIMPLE_HISTORY_REVERSE_ATTR_NAME, HistoryDescriptor
from .signals import post_create_historical_record, pre_create_historical_record
from .signals import (
post_create_historical_m2m_records,
post_create_historical_record,
pre_create_historical_m2m_records,
pre_create_historical_record,
)
from .utils import get_change_reason_from_object

try:
Expand Down Expand Up @@ -94,6 +99,8 @@ def __init__(
no_db_index=list(),
excluded_field_kwargs=None,
m2m_fields=(),
m2m_fields_model_field_name="_history_m2m_fields",
m2m_bases=(models.Model,),
):
self.user_set_verbose_name = verbose_name
self.user_set_verbose_name_plural = verbose_name_plural
Expand All @@ -114,6 +121,7 @@ def __init__(
self.related_name = related_name
self.use_base_model_db = use_base_model_db
self.m2m_fields = m2m_fields
self.m2m_fields_model_field_name = m2m_fields_model_field_name

if isinstance(no_db_index, str):
no_db_index = [no_db_index]
Expand All @@ -132,6 +140,12 @@ def __init__(
self.bases = (HistoricalChanges,) + tuple(bases)
except TypeError:
raise TypeError("The `bases` option must be a list or a tuple.")
try:
if isinstance(m2m_bases, str):
raise TypeError
self.m2m_bases = (HistoricalChanges,) + tuple(m2m_bases)
except TypeError:
raise TypeError("The `m2m_bases` option must be a list or a tuple.")

def contribute_to_class(self, cls, name):
self.manager_name = name
Expand Down Expand Up @@ -189,7 +203,10 @@ def finalize(self, sender, **kwargs):
# so the signal handlers can't use weak references.
models.signals.post_save.connect(self.post_save, sender=sender, weak=False)
models.signals.post_delete.connect(self.post_delete, sender=sender, weak=False)
for field in self.m2m_fields:

m2m_fields = self.get_m2m_fields_from_model(sender)

for field in m2m_fields:
m2m_changed.connect(
partial(self.m2m_changed, attr=field.name),
sender=field.remote_field.through,
Expand All @@ -200,13 +217,12 @@ def finalize(self, sender, **kwargs):
setattr(sender, self.manager_name, descriptor)
sender._meta.simple_history_manager_attribute = self.manager_name

for field in self.m2m_fields:
for field in m2m_fields:
m2m_model = self.create_history_m2m_model(
history_model, field.remote_field.through
)
self.m2m_models[field] = m2m_model

module = importlib.import_module(self.module)
setattr(module, m2m_model.__name__, m2m_model)

m2m_descriptor = HistoryDescriptor(m2m_model)
Expand Down Expand Up @@ -235,46 +251,18 @@ def get_history_model_name(self, model):
)

def create_history_m2m_model(self, model, through_model):
attrs = {
"__module__": self.module,
"__str__": lambda self: "{} as of {}".format(
self._meta.verbose_name, self.history.history_date
),
}

app_module = "%s.models" % model._meta.app_label

if model.__module__ != self.module:
# registered under different app
attrs["__module__"] = self.module
elif app_module != self.module:
# Abuse an internal API because the app registry is loading.
app = apps.app_configs[model._meta.app_label]
models_module = app.name
attrs["__module__"] = models_module

# Get the primary key to the history model this model will look up to
attrs["m2m_history_id"] = self._get_history_id_field()
attrs["history"] = models.ForeignKey(
model,
db_constraint=False,
on_delete=models.DO_NOTHING,
)
attrs["instance_type"] = through_model
attrs = {}

fields = self.copy_fields(through_model)
attrs.update(fields)
attrs.update(self.get_extra_fields_m2m(model, through_model, fields))

name = self.get_history_model_name(through_model)
registered_models[through_model._meta.db_table] = through_model
meta_fields = {"verbose_name": name}

if self.app:
meta_fields["app_label"] = self.app

attrs.update(Meta=type(str("Meta"), (), meta_fields))
attrs.update(Meta=type("Meta", (), self.get_meta_options_m2m(through_model)))

m2m_history_model = type(str(name), (models.Model,), attrs)
m2m_history_model = type(str(name), self.m2m_bases, attrs)

return m2m_history_model

Expand All @@ -285,7 +273,7 @@ def create_history_model(self, model, inherited):
attrs = {
"__module__": self.module,
"_history_excluded_fields": self.excluded_fields,
"_history_m2m_fields": self.m2m_fields,
"_history_m2m_fields": self.get_m2m_fields_from_model(model),
}

app_module = "%s.models" % model._meta.app_label
Expand Down Expand Up @@ -412,7 +400,7 @@ def _get_history_change_reason_field(self):

def _get_history_id_field(self):
if self.history_id_field:
history_id_field = self.history_id_field
history_id_field = self.history_id_field.clone()
history_id_field.primary_key = True
history_id_field.editable = False
elif getattr(settings, "SIMPLE_HISTORY_HISTORY_ID_USE_UUID", False):
Expand Down Expand Up @@ -465,6 +453,25 @@ def _get_history_related_field(self, model):
else:
return {}

def get_extra_fields_m2m(self, model, through_model, fields):
"""Return dict of extra fields added to the m2m historical record model"""

extra_fields = {
"__module__": model.__module__,
"__str__": lambda self: "{} as of {}".format(
self._meta.verbose_name, self.history.history_date
),
"history": models.ForeignKey(
model,
db_constraint=False,
on_delete=models.DO_NOTHING,
),
"instance_type": through_model,
"m2m_history_id": self._get_history_id_field(),
}

return extra_fields

def get_extra_fields(self, model, fields):
"""Return dict of extra fields added to the historical record model"""

Expand Down Expand Up @@ -577,6 +584,20 @@ def _date_indexing(self):
)
return result

def get_meta_options_m2m(self, through_model):
"""
Returns a dictionary of fields that will be added to
the Meta inner class of the m2m historical record model.
"""
name = self.get_history_model_name(through_model)

meta_fields = {"verbose_name": name}

if self.app:
meta_fields["app_label"] = self.app

return meta_fields

def get_meta_options(self, model):
"""
Returns a dictionary of fields that will be added to
Expand Down Expand Up @@ -637,7 +658,7 @@ def m2m_changed(self, instance, action, attr, pk_set, reverse, **_):
self.create_historical_record(instance, "~")

def create_historical_record_m2ms(self, history_instance, instance):
for field in self.m2m_fields:
for field in history_instance._history_m2m_fields:
m2m_history_model = self.m2m_models[field]
original_instance = history_instance.instance
through_model = getattr(original_instance, field.name).through
Expand All @@ -657,7 +678,21 @@ def create_historical_record_m2ms(self, history_instance, instance):
)
insert_rows.append(m2m_history_model(**insert_row))

m2m_history_model.objects.bulk_create(insert_rows)
pre_create_historical_m2m_records.send(
sender=m2m_history_model,
rows=insert_rows,
history_instance=history_instance,
instance=instance,
field=field,
)
created_rows = m2m_history_model.objects.bulk_create(insert_rows)
post_create_historical_m2m_records.send(
sender=m2m_history_model,
created_rows=created_rows,
history_instance=history_instance,
instance=instance,
field=field,
)

def create_historical_record(self, instance, history_type, using=None):
using = using if self.use_base_model_db else None
Expand Down Expand Up @@ -721,6 +756,14 @@ def get_history_user(self, instance):

return self.get_user(instance=instance, request=request)

def get_m2m_fields_from_model(self, model):
m2m_fields = set(self.m2m_fields)
try:
m2m_fields.update(getattr(model, self.m2m_fields_model_field_name))
except AttributeError:
pass
return [getattr(model, field.name).field for field in m2m_fields]


def transform_field(field):
"""Customize field appropriately for use in historical model"""
Expand Down Expand Up @@ -880,12 +923,20 @@ def diff_against(self, old_history, excluded_fields=None, included_fields=None):
if excluded_fields is None:
excluded_fields = set()

included_m2m_fields = {field.name for field in old_history._history_m2m_fields}
if included_fields is None:
included_fields = {
f.name for f in old_history.instance_type._meta.fields if f.editable
}
else:
included_m2m_fields = included_m2m_fields.intersection(included_fields)

fields = set(included_fields).difference(excluded_fields)
fields = (
set(included_fields)
.difference(included_m2m_fields)
.difference(excluded_fields)
)
m2m_fields = set(included_m2m_fields).difference(excluded_fields)

changes = []
changed_fields = []
Expand All @@ -902,11 +953,10 @@ def diff_against(self, old_history, excluded_fields=None, included_fields=None):
changed_fields.append(field)

# Separately compare m2m fields:
for field in old_history._history_m2m_fields:
for field in m2m_fields:
# First retrieve a single item to get the field names from:
reference_history_m2m_item = (
getattr(old_history, field.name).first()
or getattr(self, field.name).first()
getattr(old_history, field).first() or getattr(self, field).first()
)
history_field_names = []
if reference_history_m2m_item:
Expand All @@ -920,15 +970,13 @@ def diff_against(self, old_history, excluded_fields=None, included_fields=None):
if f.editable and f.name not in ["id", "m2m_history_id", "history"]
]

old_rows = list(
getattr(old_history, field.name).values(*history_field_names)
)
new_rows = list(getattr(self, field.name).values(*history_field_names))
old_rows = list(getattr(old_history, field).values(*history_field_names))
new_rows = list(getattr(self, field).values(*history_field_names))

if old_rows != new_rows:
change = ModelChange(field.name, old_rows, new_rows)
change = ModelChange(field, old_rows, new_rows)
changes.append(change)
changed_fields.append(field.name)
changed_fields.append(field)

return ModelDelta(changes, changed_fields, old_history, self)

Expand Down
8 changes: 8 additions & 0 deletions simple_history/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@
# Arguments: "instance", "history_instance", "history_date",
# "history_user", "history_change_reason", "using"
post_create_historical_record = django.dispatch.Signal()

# Arguments: "sender", "rows", "history_instance", "instance",
# "field"
pre_create_historical_m2m_records = django.dispatch.Signal()

# Arguments: "sender", "created_rows", "history_instance",
# "instance", "field"
post_create_historical_m2m_records = django.dispatch.Signal()
Loading

0 comments on commit 52dafb6

Please sign in to comment.