diff --git a/oscar_odin/mappings/_model_mapper.py b/oscar_odin/mappings/_model_mapper.py new file mode 100644 index 0000000..5197807 --- /dev/null +++ b/oscar_odin/mappings/_model_mapper.py @@ -0,0 +1,69 @@ +"""Extended model mapper for Django models.""" +from typing import Sequence + +from django.db.models import ManyToManyRel, ManyToOneRel, OneToOneRel +from odin.mapping import MappingBase, MappingMeta +from odin.utils import getmeta + + +class ModelMappingMeta(MappingMeta): + """Extended type of mapping meta.""" + + def __new__(cls, name, bases, attrs): + mapping_type = super().__new__(cls, name, bases, attrs) + + if mapping_type.to_obj is None: + return mapping_type + + # Extract out foreign field types. + mapping_type.one_to_one_fields = one_to_one_fields = [] + mapping_type.many_to_one_fields = many_to_one_fields = [] + mapping_type.many_to_many_fields = many_to_many_fields = [] + for relation in getmeta(mapping_type.to_obj).related_objects: + if isinstance(relation, OneToOneRel): + one_to_one_fields.append(relation.related_name) + elif isinstance(relation, ManyToOneRel): + many_to_one_fields.append(relation.related_name) + elif isinstance(relation, ManyToManyRel): + many_to_many_fields.append(relation.related_name) + + return mapping_type + + +class ModelMapping(MappingBase, metaclass=ModelMappingMeta): + """Definition of a mapping between two Objects.""" + + exclude_fields = [] + mappings = [] + one_to_one_fields: Sequence[str] = [] + many_to_one_fields: Sequence[str] = [] + many_to_many_fields: Sequence[str] = [] + + def create_object(self, **field_values): + """Create a new product model.""" + + [ + (name, field_values.pop(name)) + for name in self.one_to_one_fields + if name in field_values + ] + many_to_one_values = [ + (name, field_values.pop(name)) + for name in self.many_to_one_fields + if name in field_values + ] + [ + (name, field_values.pop(name)) + for name in self.many_to_many_fields + if name in field_values + ] + + obj = super().create_object(**field_values) + + # TODO add one_to_one_values + for name, value in many_to_one_values: + if value: + getattr(obj, name).set(value) + # TODO add many_to_many_values + + return obj diff --git a/oscar_odin/mappings/catalogue.py b/oscar_odin/mappings/catalogue.py index a98d1b3..0867127 100644 --- a/oscar_odin/mappings/catalogue.py +++ b/oscar_odin/mappings/catalogue.py @@ -4,6 +4,7 @@ import odin from django.contrib.auth.models import AbstractUser +from django.db import transaction from django.db.models import QuerySet from django.db.models.fields.files import ImageFieldFile from django.http import HttpRequest @@ -13,6 +14,7 @@ from .. import resources from ..resources.catalogue import Structure from ._common import map_queryset +from ._model_mapper import ModelMapping __all__ = ( "ProductImageToResource", @@ -205,17 +207,17 @@ def map_stock_price(self) -> Tuple[Decimal, str, int]: return Decimal(0), "", 0 -class ProductToModel(odin.Mapping): +class ProductToModel(ModelMapping): """Map from a product resource to a model.""" from_obj = resources.catalogue.Product to_obj = ProductModel - # @odin.assign_field - # def images(self) -> List[ProductImageModel]: - # """Map related image.""" - # return list(ProductImageToModel.apply(self.source.images, context=self.context)) - # + @odin.map_list_field + def images(self, values) -> List[ProductImageModel]: + """Map related image.""" + return list(ProductImageToModel.apply(values, context=self.context)) + # @odin.assign_field # def categories(self) -> List[CategoryModel]: # """Map related categories.""" @@ -307,3 +309,22 @@ def product_to_model( """Map a product resource to a model.""" model = ProductToModel.apply(product) return model + + +def product_to_db( + product: resources.catalogue.Product, +) -> ProductModel: + """Map a product resource to a model and store in the database. + + The method will handle the nested database saves required to store the entire resource + within a single transaction. + """ + model: ProductModel = product_to_model(product) + + with transaction.atomic(): + model.save() + for image in product.images: + image.product = model + image.save() + + return model diff --git a/tests/mappings/test_catalogue.py b/tests/mappings/test_catalogue.py index 0f8f3f0..08d56ac 100644 --- a/tests/mappings/test_catalogue.py +++ b/tests/mappings/test_catalogue.py @@ -42,12 +42,10 @@ def test_product_to_resource__where_is_a_parent_product_include_children(self): self.assertIsNotNone(actual.children) self.assertEqual(3, len(actual.children)) - def test_product_to_model__basic_model_to_resource(self): + def test_product_to_db__basic_model_to_resource(self): product = Product.objects.first() resource = catalogue.product_to_resource(product) - actual = catalogue.product_to_model(resource) - - actual.save() + actual = catalogue.product_to_db(resource) self.assertEqual(resource.title, actual.title)