From 4e6f14d71005086eaa0ee688726501cd1f1ac568 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Sat, 16 Mar 2024 08:48:36 +0000 Subject: [PATCH] Fixup fileIO concurrency issues --- Sources/Hummingbird/Files/FileIO.swift | 99 ++++++------------- .../Hummingbird/Files/FileMiddleware.swift | 2 +- .../FileMiddlewareTests.swift | 5 +- 3 files changed, 37 insertions(+), 69 deletions(-) diff --git a/Sources/Hummingbird/Files/FileIO.swift b/Sources/Hummingbird/Files/FileIO.swift index c6c70979f..ef92a475b 100644 --- a/Sources/Hummingbird/Files/FileIO.swift +++ b/Sources/Hummingbird/Files/FileIO.swift @@ -41,18 +41,8 @@ public struct FileIO: Sendable { /// - Returns: Response body public func loadFile(path: String, context: some BaseRequestContext) async throws -> ResponseBody { do { - let (handle, region) = try await self.fileIO.openFile(path: path, eventLoop: self.eventLoopGroup.any()).get() - context.logger.debug("[FileIO] GET", metadata: ["file": .string(path)]) - - if region.readableBytes > self.chunkSize { - return try self.streamFile(handle: handle, region: region, context: context) - } else { - // only close file handle for load, as streamer hasn't loaded data at this point - defer { - try? handle.close() - } - return try await self.loadFile(handle: handle, region: region, context: context) - } + let stat = try await fileIO.lstat(path: path) + return self.readFile(path: path, range: 0...numericCast(stat.st_size - 1), context: context) } catch { throw HTTPError(.notFound) } @@ -67,28 +57,12 @@ public struct FileIO: Sendable { /// - range:Range defining how much of the file is to be loaded /// - context: Context this request is being called in /// - Returns: Response body plus file size - public func loadFile(path: String, range: ClosedRange, context: some BaseRequestContext) async throws -> (ResponseBody, Int) { + public func loadFile(path: String, range: ClosedRange, context: some BaseRequestContext) async throws -> ResponseBody { do { - let (handle, region) = try await self.fileIO.openFile(path: path, eventLoop: self.eventLoopGroup.any()).get() - context.logger.debug("[FileIO] GET", metadata: ["file": .string(path)]) - - // work out region to load - let regionRange = region.readerIndex...region.endIndex - let range = range.clamped(to: regionRange) - // add one to upperBound as range is inclusive of upper bound - let loadRegion = FileRegion(fileHandle: handle, readerIndex: range.lowerBound, endIndex: range.upperBound + 1) - - if loadRegion.readableBytes > self.chunkSize { - let stream = try self.streamFile(handle: handle, region: loadRegion, context: context) - return (stream, region.readableBytes) - } else { - // only close file handle for load, as streamer hasn't loaded data at this point - defer { - try? handle.close() - } - let buffer = try await self.loadFile(handle: handle, region: loadRegion, context: context) - return (buffer, region.readableBytes) - } + let stat = try await fileIO.lstat(path: path) + let fileRange: ClosedRange = 0...numericCast(stat.st_size - 1) + let range = range.clamped(to: fileRange) + return self.readFile(path: path, range: range, context: context) } catch { throw HTTPError(.notFound) } @@ -101,7 +75,11 @@ public struct FileIO: Sendable { /// - contents: Request body to write. /// - path: Path to write to /// - logger: Logger - public func writeFile(contents: RequestBody, path: String, context: some BaseRequestContext) async throws { + public func writeFile( + contents: AS, + path: String, + context: some BaseRequestContext + ) async throws where AS.Element == ByteBuffer { let eventLoop = self.eventLoopGroup.any() let handle = try await self.fileIO.openFile(path: path, mode: .write, flags: .allowFileCreation(), eventLoop: eventLoop).get() defer { @@ -113,41 +91,28 @@ public struct FileIO: Sendable { } } - /// Load file as ByteBuffer - func loadFile(handle: NIOFileHandle, region: FileRegion, context: some BaseRequestContext) async throws -> ResponseBody { - let buffer = try await self.fileIO.read( - fileHandle: handle, - fromOffset: Int64(region.readerIndex), - byteCount: region.readableBytes, - allocator: context.allocator, - eventLoop: self.eventLoopGroup.any() - ).get() - return .init(byteBuffer: buffer) - } + /// Return response body that will read file + func readFile(path: String, range: ClosedRange, context: some BaseRequestContext) -> ResponseBody { + return ResponseBody(contentLength: range.count) { writer in + try await self.fileIO.withFileHandle(path: path, mode: .read) { handle in + let endOffset = range.endIndex + let chunkSize = 8 * 1024 + var fileOffset = range.startIndex - /// Return streamer that will load file - func streamFile(handle: NIOFileHandle, region: FileRegion, context: some BaseRequestContext) throws -> ResponseBody { - let fileOffset = region.readerIndex - let endOffset = region.endIndex - return ResponseBody(contentLength: region.readableBytes) { writer in - let chunkSize = 8 * 1024 - var fileOffset = fileOffset - - while fileOffset < endOffset { - let bytesLeft = endOffset - fileOffset - let bytesToRead = Swift.min(chunkSize, bytesLeft) - let fileOffsetToRead = fileOffset - let buffer = try await self.fileIO.read( - fileHandle: handle, - fromOffset: Int64(fileOffsetToRead), - byteCount: bytesToRead, - allocator: context.allocator, - eventLoop: self.eventLoopGroup.any() - ).get() - fileOffset += bytesToRead - try await writer.write(buffer) + while case .inRange(let offset) = fileOffset { + let bytesLeft = range.distance(from: fileOffset, to: endOffset) + let bytesToRead = Swift.min(chunkSize, bytesLeft) + let buffer = try await self.fileIO.read( + fileHandle: handle, + fromOffset: numericCast(offset), + byteCount: bytesToRead, + allocator: context.allocator, + eventLoop: self.eventLoopGroup.any() + ).get() + fileOffset = range.index(fileOffset, offsetBy: bytesToRead) + try await writer.write(buffer) + } } - try handle.close() } } } diff --git a/Sources/Hummingbird/Files/FileMiddleware.swift b/Sources/Hummingbird/Files/FileMiddleware.swift index 50d381c8b..91ef29b32 100644 --- a/Sources/Hummingbird/Files/FileMiddleware.swift +++ b/Sources/Hummingbird/Files/FileMiddleware.swift @@ -201,7 +201,7 @@ public struct FileMiddleware: RouterMiddleware { switch request.method { case .get: if let range { - let (body, _) = try await self.fileIO.loadFile(path: fullPath, range: range, context: context) + let body = try await self.fileIO.loadFile(path: fullPath, range: range, context: context) return Response(status: .partialContent, headers: headers, body: body) } diff --git a/Tests/HummingbirdTests/FileMiddlewareTests.swift b/Tests/HummingbirdTests/FileMiddlewareTests.swift index 404ce0574..afb95ccfd 100644 --- a/Tests/HummingbirdTests/FileMiddlewareTests.swift +++ b/Tests/HummingbirdTests/FileMiddlewareTests.swift @@ -18,7 +18,7 @@ import Hummingbird import HummingbirdTesting import XCTest -class HummingbirdFilesTests: XCTestCase { +class FileMiddlewareTests: XCTestCase { func randomBuffer(size: Int) -> ByteBuffer { var data = [UInt8](repeating: 0, count: size) data = data.map { _ in UInt8.random(in: 0...255) } @@ -89,6 +89,7 @@ class HummingbirdFilesTests: XCTestCase { let slice = buffer.getSlice(at: 100, length: 3900) XCTAssertEqual(response.body, slice) XCTAssertEqual(response.headers[.contentRange], "bytes 100-3999/326000") + XCTAssertEqual(response.headers[.contentLength], "3900") XCTAssertEqual(response.headers[.contentType], "text/plain") } @@ -96,6 +97,7 @@ class HummingbirdFilesTests: XCTestCase { let slice = buffer.getSlice(at: 0, length: 1) XCTAssertEqual(response.body, slice) XCTAssertEqual(response.headers[.contentRange], "bytes 0-0/326000") + XCTAssertEqual(response.headers[.contentLength], "1") XCTAssertEqual(response.headers[.contentType], "text/plain") } @@ -109,6 +111,7 @@ class HummingbirdFilesTests: XCTestCase { try await client.execute(uri: filename, method: .get, headers: [.range: "bytes=6000-"]) { response in let slice = buffer.getSlice(at: 6000, length: 320_000) XCTAssertEqual(response.body, slice) + XCTAssertEqual(response.headers[.contentLength], "320000") XCTAssertEqual(response.headers[.contentRange], "bytes 6000-325999/326000") } }