diff --git a/build.gradle.kts b/build.gradle.kts index 629945fb..cb32f827 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -15,6 +15,7 @@ repositories { dependencies { // Ktor dependencies + implementation(libs.ktor.server.conditionalHeaders) implementation(libs.ktor.server.contentNegotiation) implementation(libs.ktor.server.core) implementation(libs.ktor.server.cors) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 6453768e..ed94d2f2 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -21,6 +21,7 @@ ktor-client-contentNegotiation = { module = "io.ktor:ktor-client-content-negotia ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" } ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" } ktor-serializationJson = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" } +ktor-server-conditionalHeaders = { module = "io.ktor:ktor-server-conditional-headers", version.ref = "ktor" } ktor-server-contentNegotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor" } ktor-server-core = { module = "io.ktor:ktor-server-core", version.ref = "ktor" } ktor-server-cors = { module = "io.ktor:ktor-server-cors", version.ref = "ktor" } diff --git a/src/main/kotlin/server/endpoints/files/DownloadFileEndpoint.kt b/src/main/kotlin/server/endpoints/files/DownloadFileEndpoint.kt index aaea3a57..74c087a4 100644 --- a/src/main/kotlin/server/endpoints/files/DownloadFileEndpoint.kt +++ b/src/main/kotlin/server/endpoints/files/DownloadFileEndpoint.kt @@ -1,5 +1,6 @@ package server.endpoints.files +import io.ktor.http.HttpHeaders import io.ktor.server.response.header import io.ktor.server.response.respondFile import io.ktor.server.response.respondOutputStream @@ -10,7 +11,10 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import server.endpoints.EndpointBase import server.error.Errors +import server.response.FileSource +import server.response.FileUUID import server.response.respondFailure +import storage.FileType import storage.Storage import utils.ImageUtils @@ -23,6 +27,13 @@ object DownloadFileEndpoint : EndpointBase("/download/{uuid}") { val file = Storage.find(uuid) ?: return respondFailure(Errors.FileNotFound) + // Add the file's UUID to the response + call.response.header(HttpHeaders.FileUUID, uuid) + + // Check if the file is an image or a track + val source = if (file.parentFile == Storage.ImagesDir) FileType.IMAGE else FileType.TRACK + call.response.header(HttpHeaders.FileSource, source.headerValue) + // Add the file's MIME type to the response withContext(Dispatchers.IO) { Files.probeContentType(file.toPath()) diff --git a/src/main/kotlin/server/plugins/ConditionalHeaders.kt b/src/main/kotlin/server/plugins/ConditionalHeaders.kt new file mode 100644 index 00000000..9d91c5ca --- /dev/null +++ b/src/main/kotlin/server/plugins/ConditionalHeaders.kt @@ -0,0 +1,36 @@ +package server.plugins + +import io.ktor.http.HttpHeaders +import io.ktor.http.content.EntityTagVersion +import io.ktor.server.http.content.LastModifiedVersion +import io.ktor.server.plugins.conditionalheaders.ConditionalHeadersConfig +import java.security.MessageDigest +import server.response.FileSource +import server.response.FileUUID +import storage.FileType +import storage.HashUtils +import storage.MessageDigestAlgorithm + +fun ConditionalHeadersConfig.configure() { + version { call, outgoingContent -> + val fileUUID = call.response.headers[HttpHeaders.FileUUID] + val fileSource = call.response.headers[HttpHeaders.FileSource] + val fileType = FileType.entries.find { it.headerValue == fileSource } + if (fileUUID != null && fileType != null) { + val file = fileType.fetcher(fileUUID)?.takeIf { it.exists() } + if (file != null) { + val modificationDate = file.lastModified() + val checkSumSha256 = HashUtils.getCheckSumFromFile( + MessageDigest.getInstance(MessageDigestAlgorithm.SHA_256), + file + ) + return@version listOf( + EntityTagVersion(checkSumSha256), + LastModifiedVersion(modificationDate) + ) + } + } + + emptyList() + } +} diff --git a/src/main/kotlin/server/plugins/Plugins.kt b/src/main/kotlin/server/plugins/Plugins.kt index a865f822..cb45d2e3 100644 --- a/src/main/kotlin/server/plugins/Plugins.kt +++ b/src/main/kotlin/server/plugins/Plugins.kt @@ -4,6 +4,7 @@ import database.serialization.Json import io.ktor.serialization.kotlinx.json.json import io.ktor.server.application.Application import io.ktor.server.application.install +import io.ktor.server.plugins.conditionalheaders.ConditionalHeaders import io.ktor.server.plugins.contentnegotiation.ContentNegotiation import io.ktor.server.plugins.statuspages.StatusPages @@ -16,6 +17,7 @@ import io.ktor.server.plugins.statuspages.StatusPages * @receiver The application on which this method is called. */ fun Application.installPlugins() { + install(ConditionalHeaders) { configure() } install(ContentNegotiation) { json(Json) } install(StatusPages) { configureStatusPages() } } diff --git a/src/main/kotlin/server/response/CustomResponseHeaders.kt b/src/main/kotlin/server/response/CustomResponseHeaders.kt new file mode 100644 index 00000000..acc39f07 --- /dev/null +++ b/src/main/kotlin/server/response/CustomResponseHeaders.kt @@ -0,0 +1,17 @@ +package server.response + +import io.ktor.http.HttpHeaders +import storage.FileType + +/** + * A header included in the responses of file requests, containing the UUID of the file. + */ +val HttpHeaders.FileUUID: String get() = "X-File-UUID" + +/** + * A header included in the responses of file requests, where the file comes from. Basically one of the following: + * - `Images` ([FileType.IMAGE]) + * - `Tracks` ([FileType.TRACK]) + * @see FileType + */ +val HttpHeaders.FileSource: String get() = "X-File-Source" diff --git a/src/main/kotlin/storage/FileType.kt b/src/main/kotlin/storage/FileType.kt new file mode 100644 index 00000000..78700dcb --- /dev/null +++ b/src/main/kotlin/storage/FileType.kt @@ -0,0 +1,8 @@ +package storage + +import java.io.File + +enum class FileType(val headerValue: String, val fetcher: (uuid: String) -> File?) { + IMAGE("Images", Storage::imageFile), + TRACK("Tracks", Storage::trackFile) +} diff --git a/src/main/kotlin/storage/Storage.kt b/src/main/kotlin/storage/Storage.kt index 5c3f2c0c..17aeecdb 100644 --- a/src/main/kotlin/storage/Storage.kt +++ b/src/main/kotlin/storage/Storage.kt @@ -17,7 +17,9 @@ object Storage { val ImagesDir by lazy { File(BaseDir, "images").also { if (!it.exists()) it.mkdirs() } } val TracksDir by lazy { File(BaseDir, "tracks").also { if (!it.exists()) it.mkdirs() } } - fun imageFile(path: String) = File(ImagesDir, path) + fun imageFile(uuid: String) = ImagesDir.listFiles().find { it.name.startsWith(uuid) } + + fun trackFile(uuid: String) = TracksDir.listFiles().find { it.name.startsWith(uuid) } /** * Finds a file based on the given UUID. diff --git a/src/test/kotlin/server/endpoints/files/TestFileDownloading.kt b/src/test/kotlin/server/endpoints/files/TestFileDownloading.kt index c2daf723..1d272359 100644 --- a/src/test/kotlin/server/endpoints/files/TestFileDownloading.kt +++ b/src/test/kotlin/server/endpoints/files/TestFileDownloading.kt @@ -4,10 +4,14 @@ import assertions.assertSuccess import database.entity.Area import io.ktor.client.statement.bodyAsChannel import io.ktor.client.statement.readRawBytes +import io.ktor.http.HttpHeaders +import io.ktor.http.etag import io.ktor.http.isSuccess +import io.ktor.http.lastModified import io.ktor.utils.io.readBuffer import java.awt.image.BufferedImage import java.io.File +import java.security.MessageDigest import javax.imageio.ImageIO import kotlin.test.Test import kotlin.test.assertEquals @@ -17,27 +21,32 @@ import kotlinx.io.copyTo import server.DataProvider import server.base.ApplicationTestBase import server.base.StubApplicationTestBuilder +import server.response.FileSource +import server.response.FileUUID +import storage.FileType +import storage.HashUtils +import storage.MessageDigestAlgorithm import storage.Storage class TestFileDownloading : ApplicationTestBase() { private suspend inline fun StubApplicationTestBuilder.provideImageFile( imageFile: String = "/images/alcoi.jpg", - block: (imageUUID: String) -> Unit + block: (imageUUID: String, imageFile: File) -> Unit ) { val areaId = DataProvider.provideSampleArea(this, imageFile = imageFile) - var image: String? = null + var imageFile: File? = null get("/area/$areaId").apply { assertSuccess { data -> assertNotNull(data) - image = data.image.toRelativeString(Storage.ImagesDir) + imageFile = data.image } } - assertNotNull(image) + assertNotNull(imageFile) - block(image) + block(imageFile.toRelativeString(Storage.ImagesDir), imageFile) } private fun downloadResized( @@ -46,7 +55,7 @@ class TestFileDownloading : ApplicationTestBase() { fetch: (BufferedImage) -> Int, imageFile: String = "/images/alcoi.jpg" ) = test { - provideImageFile(imageFile) { image -> + provideImageFile(imageFile) { image, _ -> val tempFile = File.createTempFile("eaic", null) val response = get("/download/$image?$argument=$value") assertTrue( @@ -69,14 +78,29 @@ class TestFileDownloading : ApplicationTestBase() { @Test fun `test downloading files`() = test { - provideImageFile { image -> + provideImageFile { image, imageFile -> get("/download/$image").apply { - headers["Content-Type"].let { contentType -> - assertEquals( - "image/jpeg", - contentType, - "Content-Type header is not JPEG. Got: $contentType" + headers[HttpHeaders.ContentType].let { contentType -> + assertEquals("image/jpeg", contentType, "Content-Type header is not JPEG. Got: $contentType") + } + assertEquals(image, headers[HttpHeaders.FileUUID], "File UUID header is not correct") + assertEquals( + FileType.IMAGE.headerValue, + headers[HttpHeaders.FileSource], + "File source header is not correct." + ) + etag().let { + val hash = HashUtils.getCheckSumFromFile( + MessageDigest.getInstance(MessageDigestAlgorithm.SHA_256), + imageFile ) + assertEquals("\"$hash\"", it, "ETag header is not correct") + } + lastModified()?.time.let { + // Value may be truncated to seconds and converted again to ms, so we need to truncate it + val fileLastModified = imageFile.lastModified() / 1000 * 1000 + val headerLastModified = it?.div(1000)?.times(1000) + assertEquals(fileLastModified, headerLastModified, "Last-Modified header is not correct") } readRawBytes() }