diff --git a/object_storage_api/core/exceptions.py b/object_storage_api/core/exceptions.py index 9afca44..4bfb11c 100644 --- a/object_storage_api/core/exceptions.py +++ b/object_storage_api/core/exceptions.py @@ -54,6 +54,13 @@ class InvalidImageFileError(BaseAPIException): response_detail = "File given is not a valid image" +class InvalidFilenameExtension(BaseAPIException): + """The provided filename's extension does not match expected file type.""" + + status_code = 422 + response_detail = "Filename does not contain the correct extension" + + class MissingRecordError(DatabaseError): """A specific database record was requested but could not be found.""" diff --git a/object_storage_api/services/image.py b/object_storage_api/services/image.py index 4c92914..4d9d6f9 100644 --- a/object_storage_api/services/image.py +++ b/object_storage_api/services/image.py @@ -4,12 +4,13 @@ """ import logging +import mimetypes from typing import Annotated, Optional from bson import ObjectId from fastapi import Depends, UploadFile -from object_storage_api.core.exceptions import InvalidObjectIdError +from object_storage_api.core.exceptions import InvalidFilenameExtension, InvalidObjectIdError from object_storage_api.core.image import generate_thumbnail_base64_str from object_storage_api.models.image import ImageIn from object_storage_api.repositories.image import ImageRepo @@ -51,6 +52,7 @@ def create(self, image_metadata: ImagePostMetadataSchema, upload_file: UploadFil :param upload_file: Upload file of the image to be created. :return: Created image with an pre-signed upload URL. :raises InvalidObjectIdError: If the image has any invalid ID's in it. + :raises InvalidFilenameExtension: If the image has a mismatched file extension. """ # Generate a unique ID for the image - this needs to be known now to avoid inserting into the database @@ -63,6 +65,12 @@ def create(self, image_metadata: ImagePostMetadataSchema, upload_file: UploadFil # Upload the full size image to object storage object_key = self._image_store.upload(image_id, image_metadata, upload_file) + expected_file_type = mimetypes.guess_type(upload_file.filename)[0] + if expected_file_type != upload_file.content_type: + raise InvalidFilenameExtension( + f"File extension `{upload_file.filename}` does not match content type `{upload_file.content_type}`" + ) + try: image_in = ImageIn( **image_metadata.model_dump(), @@ -109,10 +117,20 @@ def update(self, image_id: str, image: ImagePatchMetadataSchema) -> ImageMetadat :param image_id: The ID of the image to update. :param image: The image containing the fields to be updated. :return: The updated image. + :raises InvalidFilenameExtension: If the image has a mismatched file extension. """ stored_image = self._image_repository.get(image_id=image_id) update_data = image.model_dump(exclude_unset=True) + stored_type = mimetypes.guess_type(stored_image.file_name) + if image.file_name is not None: + update_type = mimetypes.guess_type(image.file_name) + if update_type != stored_type: + raise InvalidFilenameExtension( + f"Patch filename extension `{image.file_name}` does not match " + f"stored image `{stored_image.file_name}`" + ) + update_primary = image.primary is not None and image.primary is True and stored_image.primary is False updated_image = self._image_repository.update( image_id=image_id, diff --git a/test/e2e/test_image.py b/test/e2e/test_image.py index ae86437..83b6b92 100644 --- a/test/e2e/test_image.py +++ b/test/e2e/test_image.py @@ -50,6 +50,25 @@ def post_image(self, image_post_metadata_data: dict, file_name: str) -> Optional ) return self._post_response_image.json()["id"] if self._post_response_image.status_code == 201 else None + def post_image_with_file_extension_content_type_mismatch( + self, image_post_metadata_data: dict, file_name: str + ) -> Optional[str]: + """ + Posts an image with the given metadata and a test image file, changing the filename to a mismatched extension + and throwing an error. + + :param image_post_metadata_data: Dictionary containing the image metadata data as would be required for an + `ImagePostMetadataSchema`. + :param file_name: File name of the image to upload (relative to the 'test/files' directory). + :return: ID of the created image (or `None` if not successful). + """ + + with open(f"test/files/{file_name}", mode="rb") as file: + self._post_response_image = self.test_client.post( + "/images", data={**image_post_metadata_data}, files={"upload_file": ("image.png", file, "image/jpeg")} + ) + return self._post_response_image.json()["id"] if self._post_response_image.status_code == 201 else None + def check_post_image_success(self, expected_image_get_data: dict) -> None: """ Checks that a prior call to `post_image` gave a successful response with the expected data returned. @@ -100,6 +119,14 @@ def test_create_with_invalid_image_file(self): self.post_image(IMAGE_POST_METADATA_DATA_REQUIRED_VALUES_ONLY, "invalid_image.jpg") self.check_post_image_failed_with_detail(422, "File given is not a valid image") + def test_create_with_file_extension_content_type_mismatch(self): + """Test creating an image with a mismatched file extension.""" + + self.post_image_with_file_extension_content_type_mismatch( + IMAGE_POST_METADATA_DATA_REQUIRED_VALUES_ONLY, "image.jpg" + ) + self.check_post_image_failed_with_detail(422, "Filename does not contain the correct extension") + class GetDSL(CreateDSL): """Base class for get tests.""" @@ -364,6 +391,12 @@ def test_update_invalid_id(self): self.patch_image("invalid-id", {}) self.check_patch_image_failed_with_detail(404, "Image not found") + def test_partial_update_with_file_extension_content_type_mismatch(self): + """Test updating an image with a different extension.""" + image_id = self.post_image(IMAGE_POST_METADATA_DATA_ALL_VALUES, "image.jpg") + self.patch_image(image_id, {**IMAGE_PATCH_METADATA_DATA_ALL_VALUES, "file_name": "picture.png"}) + self.check_patch_image_failed_with_detail(422, "Filename does not contain the correct extension") + def test_update_primary(self): """Test updating primary to True, triggering other database updates.""" image_id_a = self.post_image({**IMAGE_POST_METADATA_DATA_ALL_VALUES}, "image.jpg") diff --git a/test/unit/services/test_image.py b/test/unit/services/test_image.py index 39fd332..f1a3180 100644 --- a/test/unit/services/test_image.py +++ b/test/unit/services/test_image.py @@ -14,7 +14,7 @@ from bson import ObjectId from fastapi import UploadFile -from object_storage_api.core.exceptions import InvalidObjectIdError +from object_storage_api.core.exceptions import InvalidFilenameExtension, InvalidObjectIdError from object_storage_api.models.image import ImageIn, ImageOut from object_storage_api.schemas.image import ( ImageMetadataSchema, @@ -71,16 +71,18 @@ class CreateDSL(ImageServiceDSL): _created_image: ImageMetadataSchema _create_exception: pytest.ExceptionInfo - def mock_create(self, image_post_metadata_data: dict) -> None: + def mock_create(self, image_post_metadata_data: dict, filename: str) -> None: """ Mocks repo & store methods appropriately to test the `create` service method. :param image_post_metadata_data: Dictionary containing the image metadata data as would be required for an `ImagePostMetadataSchema`. + :param filename: Filename of the image. """ self._image_post_metadata = ImagePostMetadataSchema(**image_post_metadata_data) - self._upload_file = UploadFile(MagicMock(), size=100, filename="test.png", headers=MagicMock()) + header = {"content-type": "image/png"} + self._upload_file = UploadFile(MagicMock(), size=100, filename=filename, headers=header) self._expected_image_id = ObjectId() self.mock_object_id.return_value = self._expected_image_id @@ -160,14 +162,24 @@ class TestCreate(CreateDSL): def test_create(self): """Test creating an image.""" - self.mock_create(IMAGE_POST_METADATA_DATA_ALL_VALUES) + self.mock_create(IMAGE_POST_METADATA_DATA_ALL_VALUES, "test.png") self.call_create() self.check_create_success() + def test_create_with_file_extension_content_type_mismatch(self): + """Test creating an image with an inconsistent file extension and content type.""" + + self.mock_create(IMAGE_POST_METADATA_DATA_ALL_VALUES, "test.jpeg") + self.call_create_expecting_error(InvalidFilenameExtension) + self.check_create_failed_with_exception( + f"File extension `{self._upload_file.filename}` does not match " + f"content type `{self._upload_file.content_type}`" + ) + def test_create_with_invalid_entity_id(self): """Test creating an image with an invalid `entity_id`.""" - self.mock_create({**IMAGE_POST_METADATA_DATA_ALL_VALUES, "entity_id": "invalid-id"}) + self.mock_create({**IMAGE_POST_METADATA_DATA_ALL_VALUES, "entity_id": "invalid-id"}, "test.png") self.call_create_expecting_error(InvalidObjectIdError) self.check_create_failed_with_exception("Invalid ObjectId value 'invalid-id'") @@ -267,12 +279,12 @@ class UpdateDSL(ImageServiceDSL): _expected_image_out: ImageOut _updated_image_id: str _updated_image: MagicMock + _update_exception: pytest.ExceptionInfo def mock_update(self, image_patch_data: dict, stored_image_post_data: Optional[dict]) -> None: """ Mocks the repository methods appropriately to test the `update` service method. - :param image_id: ID of the image to be updated. :param image_patch_data: Dictionary containing the patch data as would be required for an `ImagePatchMetadataSchema` (i.e. no created and modified times required). :param stored_image_post_data: Dictionary containing the image data for the existing stored @@ -316,6 +328,19 @@ def call_update(self, image_id: str) -> None: self._updated_image_id = image_id self._updated_image = self.image_service.update(image_id, self._image_patch) + def call_update_expecting_error(self, image_id: str, error_type: type[BaseException]) -> None: + """ + Class the `ImageService` `update` method with the appropriate data from a prior call to `mock_update`. + while expecting an error to be raised. + + :param image_id: ID of the image to be updated. + :param error_type: Expected exception to be raised. + """ + self._updated_image_id = image_id + with pytest.raises(error_type) as exc: + self.image_service.update(image_id, self._image_patch) + self._update_exception = exc + def check_update_success(self) -> None: """Checks that a prior call to `call_update` worked as updated.""" # Ensure obtained old image @@ -330,6 +355,19 @@ def check_update_success(self) -> None: assert self._updated_image == self._expected_image_out + def check_update_failed_with_exception(self, message: str) -> None: + """ + Checks that a prior call to `call_update_expecting_error` worked as expected, raising an exception + with the correct message. + + :param message: Message of the raised exception. + """ + + self.mock_image_repository.get.assert_called_once_with(image_id=self._updated_image_id) + self.mock_image_repository.update.assert_not_called() + + assert str(self._update_exception.value) == message + class TestUpdate(UpdateDSL): """Tests for updating an image.""" @@ -356,6 +394,20 @@ def test_update_primary(self): self.call_update(image_id) self.check_update_success() + def test_update_with_file_extension_content_type_mismatch(self): + """Test updating filename to one with a mismatched file extension.""" + image_id = str(ObjectId()) + + self.mock_update( + image_patch_data={**IMAGE_PATCH_METADATA_DATA_ALL_VALUES, "file_name": "picture.png"}, + stored_image_post_data=IMAGE_IN_DATA_ALL_VALUES, + ) + self.call_update_expecting_error(image_id, InvalidFilenameExtension) + self.check_update_failed_with_exception( + f"Patch filename extension `{self._image_patch.file_name}` " + f"does not match stored image `{self._stored_image.file_name}`" + ) + class DeleteDSL(ImageServiceDSL): """Base class for `delete` tests."""