Skip to content

Commit

Permalink
feat: file upload support + better error handling (#6)
Browse files Browse the repository at this point in the history
* feat: file upload support + better error handling

* fix: http error handling

* chore: add storage docs
  • Loading branch information
drochetti authored Dec 15, 2023
1 parent 24f5310 commit 655c008
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 31 deletions.
6 changes: 3 additions & 3 deletions Sources/FalClient/Client+Codable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ public extension Client {
) async throws -> Output {
let inputData = input is EmptyInput ? nil : try encoder.encode(input)
let queryParams = inputData != nil && options.httpMethod == .get
? try JSONSerialization.jsonObject(with: inputData!) as? [String: Any]
: nil
? try Payload.create(fromJSON: inputData!)
: Payload.dict([:])

let url = buildUrl(fromId: app, path: options.path)
let data = try await sendRequest(url, input: inputData, queryParams: queryParams, options: options)
let data = try await sendRequest(to: url, input: inputData, queryParams: queryParams.asDictionary, options: options)
return try decoder.decode(Output.self, from: data)
}

Expand Down
31 changes: 29 additions & 2 deletions Sources/FalClient/Client+Request.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import Foundation

extension HTTPURLResponse {
/// Returns `true` if `statusCode` is in range 200...299.
/// Otherwise `false`.
var isSuccessful: Bool {
200 ... 299 ~= statusCode
}
}

extension Client {
func sendRequest(_ urlString: String, input: Data?, queryParams: [String: Any]? = nil, options: RunOptions) async throws -> Data {
func sendRequest(to urlString: String, input: Data?, queryParams: [String: Any]? = nil, options: RunOptions) async throws -> Data {
guard var url = URL(string: urlString) else {
throw FalError.invalidUrl(url: urlString)
}
Expand Down Expand Up @@ -42,10 +50,29 @@ extension Client {
if input != nil, options.httpMethod != .get {
request.httpBody = input
}
let (data, _) = try await URLSession.shared.data(for: request)
let (data, response) = try await URLSession.shared.data(for: request)
try checkResponseStatus(for: response, withData: data)
return data
}

func checkResponseStatus(for response: URLResponse, withData data: Data) throws {
guard let httpResponse = response as? HTTPURLResponse else {
throw FalError.invalidResultFormat
}
if let httpResponse = response as? HTTPURLResponse, !httpResponse.isSuccessful {
let errorPayload = try? Payload.create(fromJSON: data)
let statusCode = httpResponse.statusCode
let message = errorPayload?["detail"].stringValue
?? errorPayload?.stringValue
?? HTTPURLResponse.localizedString(forStatusCode: statusCode)
throw FalError.httpError(
status: statusCode,
message: message,
payload: errorPayload
)
}
}

var userAgent: String {
let osVersion = ProcessInfo.processInfo.operatingSystemVersionString
return "fal.ai/swift-client 0.1.0 - \(osVersion)"
Expand Down
2 changes: 2 additions & 0 deletions Sources/FalClient/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public protocol Client {

var realtime: Realtime { get }

var storage: Storage { get }

func run(_ id: String, input: Payload?, options: RunOptions) async throws -> Payload

func subscribe(
Expand Down
16 changes: 12 additions & 4 deletions Sources/FalClient/FalClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,21 @@ public struct FalClient: Client {

public var realtime: Realtime { RealtimeClient(client: self) }

public var storage: Storage { StorageClient(client: self) }

public func run(_ app: String, input: Payload?, options: RunOptions) async throws -> Payload {
let inputData = input != nil ? try JSONEncoder().encode(input) : nil
var requestInput = input
if let storage = storage as? StorageClient,
let input,
options.httpMethod != .get,
input.hasBinaryData
{
requestInput = try await storage.autoUpload(input: input)
}
let queryParams = options.httpMethod == .get ? input : nil
let url = buildUrl(fromId: app, path: options.path)
let data = try await sendRequest(url, input: inputData, queryParams: queryParams?.asDictionary, options: options)
let decoder = JSONDecoder()
return try decoder.decode(Payload.self, from: data)
let data = try await sendRequest(to: url, input: requestInput?.json(), queryParams: queryParams?.asDictionary, options: options)
return try .create(fromJSON: data)
}

public func subscribe(
Expand Down
1 change: 1 addition & 0 deletions Sources/FalClient/FalError.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

enum FalError: Error {
case httpError(status: Int, message: String, payload: Payload?)
case invalidResultFormat
case invalidUrl(url: String)
case unauthorized(message: String)
Expand Down
68 changes: 61 additions & 7 deletions Sources/FalClient/Payload.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Foundation
import MessagePack

/// Represents a value that can be encoded and decoded. This data structure
/// is used to represent the input and output of the model API and closely
Expand Down Expand Up @@ -125,17 +126,33 @@ extension Payload: ExpressibleByDictionaryLiteral {

public extension Payload {
subscript(key: String) -> Payload {
if case let .dict(dict) = self, let value = dict[key] {
return value
get {
if case let .dict(dict) = self, let value = dict[key] {
return value
}
return .nilValue
}
set(newValue) {
if case var .dict(dict) = self {
dict[key] = newValue
self = .dict(dict)
}
}
return .nilValue
}

subscript(index: Int) -> Payload {
if case let .array(arr) = self, arr.indices.contains(index) {
return arr[index]
get {
if case let .array(arr) = self, arr.indices.contains(index) {
return arr[index]
}
return .nilValue
}
set(newValue) {
if case var .array(arr) = self {
arr[index] = newValue
self = .array(arr)
}
}
return .nilValue
}
}

Expand Down Expand Up @@ -181,7 +198,7 @@ extension Payload: Equatable {
}
}

// MARK: - Converto to native types
// MARK: - Convert to native types

extension Payload {
var nativeValue: Any {
Expand Down Expand Up @@ -214,3 +231,40 @@ extension Payload {
return value.mapValues { $0.nativeValue }
}
}

// MARK: - Codable utilities

public extension Payload {
static func create(fromJSON data: Data) throws -> Payload {
try JSONDecoder().decode(Payload.self, from: data)
}

static func create(fromBinary data: Data) throws -> Payload {
try MessagePackDecoder().decode(Payload.self, from: data)
}

func json() throws -> Data {
try JSONEncoder().encode(self)
}

func binary() throws -> Data {
try MessagePackEncoder().encode(self)
}
}

// MARK: - Utilities

extension Payload {
var hasBinaryData: Bool {
switch self {
case .data:
return true
case let .array(array):
return array.contains { $0.hasBinaryData }
case let .dict(dict):
return dict.values.contains { $0.hasBinaryData }
default:
return false
}
}
}
33 changes: 22 additions & 11 deletions Sources/FalClient/Realtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func throttle<T>(_ function: @escaping (T) -> Void, throttleInterval: DispatchTi
}

public enum FalRealtimeError: Error {
case connectionError
case connectionError(code: Int? = nil)
case unauthorized
case invalidInput
case invalidResult
Expand All @@ -27,8 +27,8 @@ public enum FalRealtimeError: Error {
extension FalRealtimeError: LocalizedError {
public var errorDescription: String? {
switch self {
case .connectionError:
return NSLocalizedString("Connection error", comment: "FalRealtimeError.connectionError")
case let .connectionError(code):
return NSLocalizedString("Connection error (code: \(String(describing: code)))", comment: "FalRealtimeError.connectionError")
case .unauthorized:
return NSLocalizedString("Unauthorized", comment: "FalRealtimeError.unauthorized")
case .invalidInput:
Expand Down Expand Up @@ -209,14 +209,14 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {
// TODO: improve app alias resolution
let appAlias = app.split(separator: "-").dropFirst().joined(separator: "-")
let url = "https://rest.alpha.fal.ai/tokens/"
let body = try? JSONSerialization.data(withJSONObject: [
"allowed_apps": [appAlias],
let body: Payload = [
"allowed_apps": [.string(appAlias)],
"token_expiration": 300,
])
]
do {
let response = try await self.client.sendRequest(
url,
input: body,
to: url,
input: body.json(),
options: .withMethod(.post)
)
if let token = String(data: response, encoding: .utf8) {
Expand Down Expand Up @@ -250,8 +250,15 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {
self?.onError(error)
}
case let .failure(error):
self?.onError(error)
self?.task = nil
if let posixError = error as? POSIXError, posixError.code == .ENOTCONN {
// Ignore this error as it's thrown by Foundation's WebSocket implementation
// when messages were requested but the connection was closed already.
// This is safe to ignore, as the client is not expecting any other messages
// and will reconnect when new messages are sent.
return
}
self?.onError(error)
}
}
}
Expand All @@ -275,6 +282,7 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {

func close() {
task?.cancel(with: .normalClosure, reason: "Programmatically closed".data(using: .utf8))
task = nil
}

func urlSession(
Expand All @@ -295,9 +303,12 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {
func urlSession(
_: URLSession,
webSocketTask _: URLSessionWebSocketTask,
didCloseWith _: URLSessionWebSocketTask.CloseCode,
didCloseWith code: URLSessionWebSocketTask.CloseCode,
reason _: Data?
) {
if code != .normalClosure {
onError(FalRealtimeError.connectionError(code: code.rawValue))
}
task = nil
}
}
Expand All @@ -323,7 +334,7 @@ func isSuccessResult(_ message: Payload) -> Bool {
}

func getError(_ message: Payload) -> FalRealtimeError? {
if message["type"].stringValue != "x-fal-error",
if message["type"].stringValue == "x-fal-error",
let error = message["error"].stringValue,
let reason = message["reason"].stringValue
{
Expand Down
Loading

0 comments on commit 655c008

Please sign in to comment.