diff --git a/simple_history/tests/tests/test_utils.py b/simple_history/tests/tests/test_utils.py index 7db701d9..ed034f98 100644 --- a/simple_history/tests/tests/test_utils.py +++ b/simple_history/tests/tests/test_utils.py @@ -37,6 +37,7 @@ get_m2m_reverse_field_name, update_change_reason, ) +from .utils import db_supports_returning_autofield_pks User = get_user_model() @@ -155,6 +156,28 @@ def test_bulk_create_history_with_disabled_setting(self): self.assertEqual(Poll.objects.count(), 5) self.assertEqual(Poll.history.count(), 0) + def test_bulk_create_history_without_pks(self): + for poll in self.data: + poll.pk = None + + # An extra query must be made on some DBs to retrieve the auto-generated PKs + with self.assertNumQueries(2 if db_supports_returning_autofield_pks() else 3): + bulk_create_with_history(self.data, Poll) + + self.assertEqual(Poll.objects.count(), 5) + self.assertEqual(Poll.history.count(), 5) + + def test_bulk_create_history_without_some_pks(self): + self.data[1].pk = None + self.data[3].pk = None + + # An extra query must be made on some DBs to retrieve the auto-generated PKs + with self.assertNumQueries(3 if db_supports_returning_autofield_pks() else 4): + bulk_create_with_history(self.data, Poll) + + self.assertEqual(Poll.objects.count(), 5) + self.assertEqual(Poll.history.count(), 5) + def test_bulk_create_history_alternative_manager(self): bulk_create_with_history( self.data, diff --git a/simple_history/tests/tests/utils.py b/simple_history/tests/tests/utils.py index ae6fe949..ee9c33ba 100644 --- a/simple_history/tests/tests/utils.py +++ b/simple_history/tests/tests/utils.py @@ -1,10 +1,20 @@ +import sys from enum import Enum from typing import Type from django.conf import settings from django.db.models import Model +from django.db import connection from django.test import TestCase + +# DEV: Replace this with just `from functools import cache` +# when support for Python 3.8 has been dropped +if sys.version_info < (3, 9): + from functools import lru_cache as cache +else: + from functools import cache + request_middleware = "simple_history.middleware.HistoryRequestMiddleware" OTHER_DB_NAME = "other" @@ -35,6 +45,16 @@ def assertRecordValues(self, record, klass: Type[Model], values_dict: dict): self.assertEqual(getattr(record.history_object, key), value) +@cache +def db_supports_returning_autofield_pks() -> bool: + # See https://docs.djangoproject.com/en/stable/ref/models/querysets/#bulk-create + return connection.display_name.lower() in { + "postgresql", + "mariadb", + "sqlite", + } + + class TestDbRouter: def db_for_read(self, model, **hints): if model._meta.app_label == "external": diff --git a/simple_history/utils.py b/simple_history/utils.py index a5bafeaf..e2b73c28 100644 --- a/simple_history/utils.py +++ b/simple_history/utils.py @@ -125,7 +125,11 @@ def bulk_create_with_history( objs_with_id = model_manager.bulk_create( objs, batch_size=batch_size, ignore_conflicts=ignore_conflicts ) - if objs_with_id and objs_with_id[0].pk and not ignore_conflicts: + if ( + objs_with_id + and all(obj.pk is not None for obj in objs_with_id) + and not ignore_conflicts + ): second_transaction_required = False history_manager.bulk_history_create( objs_with_id,