Skip to content

Commit

Permalink
Fixup fileIO concurrency issues
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler committed Mar 16, 2024
1 parent dd36634 commit 4e6f14d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 69 deletions.
99 changes: 32 additions & 67 deletions Sources/Hummingbird/Files/FileIO.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,8 @@ public struct FileIO: Sendable {
/// - Returns: Response body
public func loadFile(path: String, context: some BaseRequestContext) async throws -> ResponseBody {
do {
let (handle, region) = try await self.fileIO.openFile(path: path, eventLoop: self.eventLoopGroup.any()).get()
context.logger.debug("[FileIO] GET", metadata: ["file": .string(path)])

if region.readableBytes > self.chunkSize {
return try self.streamFile(handle: handle, region: region, context: context)
} else {
// only close file handle for load, as streamer hasn't loaded data at this point
defer {
try? handle.close()
}
return try await self.loadFile(handle: handle, region: region, context: context)
}
let stat = try await fileIO.lstat(path: path)
return self.readFile(path: path, range: 0...numericCast(stat.st_size - 1), context: context)
} catch {
throw HTTPError(.notFound)
}
Expand All @@ -67,28 +57,12 @@ public struct FileIO: Sendable {
/// - range:Range defining how much of the file is to be loaded
/// - context: Context this request is being called in
/// - Returns: Response body plus file size
public func loadFile(path: String, range: ClosedRange<Int>, context: some BaseRequestContext) async throws -> (ResponseBody, Int) {
public func loadFile(path: String, range: ClosedRange<Int>, context: some BaseRequestContext) async throws -> ResponseBody {
do {
let (handle, region) = try await self.fileIO.openFile(path: path, eventLoop: self.eventLoopGroup.any()).get()
context.logger.debug("[FileIO] GET", metadata: ["file": .string(path)])

// work out region to load
let regionRange = region.readerIndex...region.endIndex
let range = range.clamped(to: regionRange)
// add one to upperBound as range is inclusive of upper bound
let loadRegion = FileRegion(fileHandle: handle, readerIndex: range.lowerBound, endIndex: range.upperBound + 1)

if loadRegion.readableBytes > self.chunkSize {
let stream = try self.streamFile(handle: handle, region: loadRegion, context: context)
return (stream, region.readableBytes)
} else {
// only close file handle for load, as streamer hasn't loaded data at this point
defer {
try? handle.close()
}
let buffer = try await self.loadFile(handle: handle, region: loadRegion, context: context)
return (buffer, region.readableBytes)
}
let stat = try await fileIO.lstat(path: path)
let fileRange: ClosedRange<Int> = 0...numericCast(stat.st_size - 1)
let range = range.clamped(to: fileRange)
return self.readFile(path: path, range: range, context: context)
} catch {
throw HTTPError(.notFound)
}
Expand All @@ -101,7 +75,11 @@ public struct FileIO: Sendable {
/// - contents: Request body to write.
/// - path: Path to write to
/// - logger: Logger
public func writeFile(contents: RequestBody, path: String, context: some BaseRequestContext) async throws {
public func writeFile<AS: AsyncSequence>(
contents: AS,
path: String,
context: some BaseRequestContext
) async throws where AS.Element == ByteBuffer {
let eventLoop = self.eventLoopGroup.any()
let handle = try await self.fileIO.openFile(path: path, mode: .write, flags: .allowFileCreation(), eventLoop: eventLoop).get()
defer {
Expand All @@ -113,41 +91,28 @@ public struct FileIO: Sendable {
}
}

/// Load file as ByteBuffer
func loadFile(handle: NIOFileHandle, region: FileRegion, context: some BaseRequestContext) async throws -> ResponseBody {
let buffer = try await self.fileIO.read(
fileHandle: handle,
fromOffset: Int64(region.readerIndex),
byteCount: region.readableBytes,
allocator: context.allocator,
eventLoop: self.eventLoopGroup.any()
).get()
return .init(byteBuffer: buffer)
}
/// Return response body that will read file
func readFile(path: String, range: ClosedRange<Int>, context: some BaseRequestContext) -> ResponseBody {
return ResponseBody(contentLength: range.count) { writer in
try await self.fileIO.withFileHandle(path: path, mode: .read) { handle in
let endOffset = range.endIndex
let chunkSize = 8 * 1024
var fileOffset = range.startIndex

/// Return streamer that will load file
func streamFile(handle: NIOFileHandle, region: FileRegion, context: some BaseRequestContext) throws -> ResponseBody {
let fileOffset = region.readerIndex
let endOffset = region.endIndex
return ResponseBody(contentLength: region.readableBytes) { writer in
let chunkSize = 8 * 1024
var fileOffset = fileOffset

while fileOffset < endOffset {
let bytesLeft = endOffset - fileOffset
let bytesToRead = Swift.min(chunkSize, bytesLeft)
let fileOffsetToRead = fileOffset
let buffer = try await self.fileIO.read(
fileHandle: handle,
fromOffset: Int64(fileOffsetToRead),
byteCount: bytesToRead,
allocator: context.allocator,
eventLoop: self.eventLoopGroup.any()
).get()
fileOffset += bytesToRead
try await writer.write(buffer)
while case .inRange(let offset) = fileOffset {
let bytesLeft = range.distance(from: fileOffset, to: endOffset)
let bytesToRead = Swift.min(chunkSize, bytesLeft)
let buffer = try await self.fileIO.read(
fileHandle: handle,
fromOffset: numericCast(offset),
byteCount: bytesToRead,
allocator: context.allocator,
eventLoop: self.eventLoopGroup.any()
).get()
fileOffset = range.index(fileOffset, offsetBy: bytesToRead)
try await writer.write(buffer)
}
}
try handle.close()
}
}
}
2 changes: 1 addition & 1 deletion Sources/Hummingbird/Files/FileMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public struct FileMiddleware<Context: BaseRequestContext>: RouterMiddleware {
switch request.method {
case .get:
if let range {
let (body, _) = try await self.fileIO.loadFile(path: fullPath, range: range, context: context)
let body = try await self.fileIO.loadFile(path: fullPath, range: range, context: context)
return Response(status: .partialContent, headers: headers, body: body)
}

Expand Down
5 changes: 4 additions & 1 deletion Tests/HummingbirdTests/FileMiddlewareTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Hummingbird
import HummingbirdTesting
import XCTest

class HummingbirdFilesTests: XCTestCase {
class FileMiddlewareTests: XCTestCase {
func randomBuffer(size: Int) -> ByteBuffer {
var data = [UInt8](repeating: 0, count: size)
data = data.map { _ in UInt8.random(in: 0...255) }
Expand Down Expand Up @@ -89,13 +89,15 @@ class HummingbirdFilesTests: XCTestCase {
let slice = buffer.getSlice(at: 100, length: 3900)
XCTAssertEqual(response.body, slice)
XCTAssertEqual(response.headers[.contentRange], "bytes 100-3999/326000")
XCTAssertEqual(response.headers[.contentLength], "3900")
XCTAssertEqual(response.headers[.contentType], "text/plain")
}

try await client.execute(uri: filename, method: .get, headers: [.range: "bytes=0-0"]) { response in
let slice = buffer.getSlice(at: 0, length: 1)
XCTAssertEqual(response.body, slice)
XCTAssertEqual(response.headers[.contentRange], "bytes 0-0/326000")
XCTAssertEqual(response.headers[.contentLength], "1")
XCTAssertEqual(response.headers[.contentType], "text/plain")
}

Expand All @@ -109,6 +111,7 @@ class HummingbirdFilesTests: XCTestCase {
try await client.execute(uri: filename, method: .get, headers: [.range: "bytes=6000-"]) { response in
let slice = buffer.getSlice(at: 6000, length: 320_000)
XCTAssertEqual(response.body, slice)
XCTAssertEqual(response.headers[.contentLength], "320000")
XCTAssertEqual(response.headers[.contentRange], "bytes 6000-325999/326000")
}
}
Expand Down

0 comments on commit 4e6f14d

Please sign in to comment.