Skip to content

Commit

Permalink
Migrate from PromiseKit to Swift async/await for turms-client-swift +…
Browse files Browse the repository at this point in the history
… Support sending requests simultaneously
  • Loading branch information
JamesChenX committed Aug 3, 2024
1 parent c21188d commit ebaad9f
Show file tree
Hide file tree
Showing 26 changed files with 1,520 additions and 1,699 deletions.
2 changes: 1 addition & 1 deletion turms-client-swift/Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ let package = Package(
targets: [
.target(
name: "TurmsClient",
dependencies: ["PromiseKit", "SwiftProtobuf"]
dependencies: ["SwiftProtobuf"]
),
.testTarget(
name: "TurmsClientTests",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import Foundation
import PromiseKit

public class BaseService {
let stateStore: StateStore
Expand All @@ -8,9 +7,7 @@ public class BaseService {
self.stateStore = stateStore
}

func close() -> Promise<Void> {
return Promise.value(())
}
func close() async {}

func onDisconnected(_: Error? = nil) {}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import Foundation
import PromiseKit

private class MessageDecoder {
private static let maxReadBufferCapacity = 8 * 1024 * 1024
Expand Down Expand Up @@ -73,12 +72,13 @@ public class ConnectionService: BaseService {
private let initialPort: UInt16
private let initialConnectTimeout: TimeInterval

private var disconnectPromises: [Resolver<Void>] = []
private var disconnectContinuations: [UnsafeContinuation<Void, Never>] = []

private var onConnectedListeners: [() -> Void] = []
private var onDisconnectedListeners: [(Error?) -> Void] = []
private var messageListeners: [(Data) -> Void] = []

private let lock = Lock()
private let decoder = MessageDecoder()

init(stateStore: StateStore, host: String? = nil, port: UInt16? = nil, connectTimeout: TimeInterval? = nil) {
Expand All @@ -89,7 +89,7 @@ public class ConnectionService: BaseService {
}

private func resetStates() {
fulfillDisconnectPromises()
fulfillDisconnectContinuations()
}

// Listeners
Expand Down Expand Up @@ -124,48 +124,43 @@ public class ConnectionService: BaseService {
}
}

private func fulfillDisconnectPromises() {
repeat {
disconnectPromises.popLast()?.fulfill(())
} while !disconnectPromises.isEmpty
private func fulfillDisconnectContinuations() {
lock.locked {
repeat {
disconnectContinuations.popLast()?.resume()
} while !disconnectContinuations.isEmpty
}
}

// Connection

public func connect(host: String? = nil, port: UInt16? = nil, connectTimeout _: TimeInterval? = nil, useTls: Bool? = false, certificatePinning: CertificatePinning? = nil) -> Promise<Void> {
return Promise { seal in
if stateStore.isConnected {
seal.reject(ResponseError(code: .clientSessionAlreadyEstablished))
return
}
resetStates()
let tcp = TcpClient(onClosed: { [weak self] error in
self?.onSocketClosed(error)
}, onDataReceived: { [weak self] data in
guard let s = self else { return }
let messages = try s.decoder.decodeMessages(data)
for message in messages {
s.notifyMessageListeners(message)
}
})
tcp.connect(host: host ?? initialHost, port: port ?? initialPort, useTls: useTls ?? false, certificatePinning: certificatePinning)
.done { [weak self] in
self?.onSocketOpened()
seal.fulfill_()
}.catch { error in
seal.reject(error)
}
stateStore.tcp = tcp
public func connect(host: String? = nil, port: UInt16? = nil, connectTimeout _: TimeInterval? = nil, useTls: Bool? = false, certificatePinning: CertificatePinning? = nil) async throws {
if stateStore.isConnected {
throw ResponseError(code: .clientSessionAlreadyEstablished)
}
resetStates()
let tcp = TcpClient(onClosed: { [weak self] error in
self?.onSocketClosed(error)
}, onDataReceived: { [weak self] data in
guard let s = self else { return }
let messages = try s.decoder.decodeMessages(data)
for message in messages {
s.notifyMessageListeners(message)
}
})
try await tcp.connect(host: host ?? initialHost, port: port ?? initialPort, useTls: useTls ?? false, certificatePinning: certificatePinning)
onSocketOpened()
stateStore.tcp = tcp
}

public func disconnect() -> Promise<Void> {
return Promise { seal in
public func disconnect() async {
await withUnsafeContinuation { continuation in
if !stateStore.isConnected {
seal.fulfill(())
return
return continuation.resume()
}
lock.locked {
disconnectContinuations.append(continuation)
}
disconnectPromises.append(seal)
stateStore.tcp!.close()
}
}
Expand All @@ -180,13 +175,13 @@ public class ConnectionService: BaseService {
private func onSocketClosed(_ error: Error?) {
decoder.clear()
stateStore.isConnected = false
fulfillDisconnectPromises()
fulfillDisconnectContinuations()
notifyOnDisconnectedListeners(error)
}

// Base methods

override func close() -> Promise<Void> {
return disconnect()
override func close() async {
return await disconnect()
}
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import Foundation
import PromiseKit

private class RequestContext {
let continuation: UnsafeContinuation<TurmsNotification, any Error>
let timeoutTask: Task<Void, Never>?

init(continuation: UnsafeContinuation<TurmsNotification, any Error>, timeoutTask: Task<Void, Never>?) {
self.continuation = continuation
self.timeoutTask = timeoutTask
}
}

class DriverMessageService: BaseService {
private let requestTimeout: TimeInterval
private let minRequestInterval: TimeInterval
private var notificationListeners: [(TurmsNotification) -> Void] = []
private var requestMap: [Int64: Resolver<TurmsNotification>] = [:]
private var requestIdToContext: [Int64: RequestContext] = [:]
private var lastRequestDate = Date(timeIntervalSince1970: 0)
private let requestLock = Lock()

init(stateStore: StateStore, requestTimeout: TimeInterval? = nil, minRequestInterval: TimeInterval? = nil) {
self.requestTimeout = requestTimeout ?? 60
Expand All @@ -28,64 +38,88 @@ class DriverMessageService: BaseService {

// Request and notification

func sendRequest(_ populator: (inout TurmsRequest) -> Void) -> Promise<TurmsNotification> {
func sendRequest(_ populator: (inout TurmsRequest) -> Void) async throws -> TurmsNotification {
var request = TurmsRequest()
populator(&request)
return sendRequest(&request)
return try await sendRequest(&request)
}

func sendRequest(_ request: inout TurmsRequest) -> Promise<TurmsNotification> {
return Promise { seal in
func sendRequest(_ request: inout TurmsRequest) async throws -> TurmsNotification {
return try await withUnsafeThrowingContinuation { continuation in
if case .createSessionRequest = request.kind {
if stateStore.isSessionOpen {
return seal.reject(ResponseError(code: .clientSessionAlreadyEstablished))
return continuation.resume(throwing: ResponseError(code: .clientSessionAlreadyEstablished))
}
} else if !stateStore.isConnected || !stateStore.isSessionOpen {
return seal.reject(ResponseError(code: .clientSessionHasBeenClosed))
return continuation.resume(throwing: ResponseError(code: .clientSessionHasBeenClosed))
}
guard let tcp = stateStore.tcp else {
return continuation.resume(throwing: ResponseError(code: .clientSessionHasBeenClosed))
}
let now = Date()
let difference = now.timeIntervalSince1970 - lastRequestDate.timeIntervalSince1970
let isFrequent = minRequestInterval > 0 && difference <= minRequestInterval
if isFrequent {
return seal.reject(ResponseError(code: .clientRequestsTooFrequent))
return continuation.resume(throwing: ResponseError(code: .clientRequestsTooFrequent))
}
request.requestID = generateRandomId()
if requestTimeout > 0 {
after(.seconds(Int(requestTimeout))).done {
seal.reject(ResponseError(code: .requestTimeout))
requestLock.locked {
let requestId = generateRandomId()
request.requestID = requestId
let data: Data
do {
data = try request.serializedData()
} catch {
return continuation.resume(throwing: ResponseError(code: .invalidRequest, reason: "Failed to serialize the request: \(request)", cause: error))
}
var timeoutTask: Task<Void, Never>?
if requestTimeout > 0 {
timeoutTask = Task {
do {
try await Task.sleep(nanoseconds: UInt64(requestTimeout * 1_000_000_000))
requestLock.locked {
requestIdToContext.removeValue(forKey: requestId)?.continuation.resume(throwing: ResponseError(code: .requestTimeout))
}
} catch {}
}
}
requestIdToContext.updateValue(RequestContext(continuation: continuation, timeoutTask: timeoutTask), forKey: request.requestID)
stateStore.lastRequestDate = now
Task {
do {
try await tcp.writeVarIntLengthAndBytes(data)
} catch {
requestLock.locked {
requestIdToContext.removeValue(forKey: requestId)?.continuation.resume(throwing: ResponseError(code: .networkError, reason: "Failed to write", cause: error))
}
}
}
}
let data: Data
do {
data = try request.serializedData()
} catch {
seal.reject(ResponseError(code: .invalidRequest, reason: "Failed to serialize the request: \(request)", cause: error))
return
}
requestMap.updateValue(seal, forKey: request.requestID)
stateStore.lastRequestDate = now
stateStore.tcp!.writeVarIntLengthAndBytes(data)
}
}

func didReceiveNotification(_ notification: TurmsNotification) {
let isResponse = !notification.hasRelayedRequest && notification.hasRequestID
if isResponse {
let requestId = notification.requestID
let handler = requestMap[requestId]
if notification.hasCode {
let code = Int(notification.code)
if ResponseStatusCode.isSuccessCode(code) {
handler?.fulfill(notification)
} else {
if notification.hasReason {
handler?.reject(ResponseError(code: code, reason: notification.reason))
requestLock.locked {
if let context = requestIdToContext.removeValue(forKey: requestId) {
context.timeoutTask?.cancel()
let continuation = context.continuation
if notification.hasCode {
let code = Int(notification.code)
if ResponseStatusCode.isSuccessCode(code) {
continuation.resume(returning: notification)
} else {
if notification.hasReason {
continuation.resume(throwing: ResponseError(code: code, reason: notification.reason))
} else {
continuation.resume(throwing: ResponseError(code: code))
}
}
} else {
handler?.reject(ResponseError(code: code))
continuation.resume(throwing: ResponseError(code: ResponseStatusCode.invalidNotification, reason: "The code is missing"))
}
}
} else {
handler?.reject(ResponseError(code: ResponseStatusCode.invalidNotification, reason: "The code is missing"))
}
}
notifyNotificationListener(notification)
Expand All @@ -95,22 +129,23 @@ class DriverMessageService: BaseService {
var id: Int64
repeat {
id = Int64.random(in: 1 ... Int64.max)
} while requestMap.keys.contains(id)
} while requestIdToContext.keys.contains(id)
return id
}

private func rejectRequests(_ e: ResponseError) {
repeat {
requestMap.popFirst()?.value.reject(e)
} while !requestMap.isEmpty
requestLock.locked {
repeat {
requestIdToContext.popFirst()?.value.continuation.resume(throwing: e)
} while !requestIdToContext.isEmpty
}
}

override func close() -> Promise<Void> {
override func close() async {
onDisconnected()
return Promise.value(())
}

override func onDisconnected(_ error: Error? = nil) {
rejectRequests(ResponseError(code: .clientSessionHasBeenClosed, cause: error))
}
}
}
Loading

0 comments on commit ebaad9f

Please sign in to comment.