diff --git a/oscar_odin/mappings/context.py b/oscar_odin/mappings/context.py index 8ebb918..2cd9593 100644 --- a/oscar_odin/mappings/context.py +++ b/oscar_odin/mappings/context.py @@ -84,25 +84,28 @@ def __bool__(self): def prepare_instance_for_validation(self, instance): return instance + def get_identity(self, instance, identifiers): + if not identifiers: + return + return attrgetter(*identifiers)(instance) + def validate_instances(self, instances, validate_unique=True, fields=None): - if not self.clean_instances: + if not self.clean_instances or not instances: return instances validated_instances = [] identities = [] exclude = () - if fields and instances: + if fields: all_fields = instances[0]._meta.fields exclude = [f.name for f in all_fields if f.name not in fields] - try: - identifier = self.identifier_mapping.get(instances[0].__class__)[0] - except (IndexError, TypeError): - identifier = None + identifiers = self.identifier_mapping.get(instances[0].__class__) for instance in instances: - if identifier is None or getattr(instance, identifier) not in identities: - if identifier is not None: - identities.append(getattr(instance, identifier)) + identity = self.get_identity(instance, identifiers) + if identifiers is None or identity not in identities: + if identifiers is not None: + identities.append(identity) try: instance = self.prepare_instance_for_validation(instance) instance.full_clean( @@ -277,6 +280,8 @@ def bulk_update_or_create_one_to_many(self): if self.delete_related: for relation, keys in identities.items(): + # instance_identifier here is product upc, if multiple identifiers for + # a product are used, then the following code must be updated. instance_identifier = self.identifier_mapping.get( relation.remote_field.related_model )[0] diff --git a/tests/reverse/test_deleting_related.py b/tests/reverse/test_deleting_related.py index a6eb025..a31ee35 100644 --- a/tests/reverse/test_deleting_related.py +++ b/tests/reverse/test_deleting_related.py @@ -77,7 +77,7 @@ def test_deleting_product_related_models(self): description="description", structure=Product.STANDALONE, is_discountable=True, - price=D("20"), + price=D("10"), availability=2, currency="EUR", partner=partner, @@ -124,7 +124,9 @@ def test_deleting_product_related_models(self): ), ] + self.assertEqual(Stockrecord.objects.count(), 0) _, errors = products_to_db(product_resources) + self.assertEqual(Stockrecord.objects.count(), 2) self.assertEqual(len(errors), 0) prd = Product.objects.get(upc="1234323-2") prd_563 = Product.objects.get(upc="563-2") @@ -135,6 +137,10 @@ def test_deleting_product_related_models(self): self.assertEqual(prd.stockrecords.count(), 1) self.assertTrue(prd.stockrecords.filter(partner=partner).exists()) + self.assertEqual(prd.stockrecords.first().price, D("10")) + self.assertEqual(prd_563.stockrecords.count(), 1) + self.assertTrue(prd_563.stockrecords.filter(partner=partner).exists()) + self.assertEqual(prd_563.stockrecords.first().price, D("20")) self.assertEqual(prd.categories.count(), 2) self.assertTrue(prd.categories.filter(code="1").exists())