From 1d905f6831eaaa43e83cb8fd02b6d5ce8b3ac2c5 Mon Sep 17 00:00:00 2001 From: Dave Sperling Date: Wed, 14 Mar 2018 09:17:31 -0700 Subject: [PATCH] Permit reuse of connection within result callback (#45) - terminates the first query fetcher when a second query is made using the same connection from within a result callback, to allow the inner query fetcher to function properly - resolves #23 and #24 --- .../PostgreSQLConnection.swift | 91 +++++++++++++++++-- .../PostgreSQLResultFetcher.swift | 8 +- Sources/SwiftKueryPostgreSQL/Utils.swift | 11 ++- .../TestSelect.swift | 62 +++++++++++++ 4 files changed, 154 insertions(+), 18 deletions(-) diff --git a/Sources/SwiftKueryPostgreSQL/PostgreSQLConnection.swift b/Sources/SwiftKueryPostgreSQL/PostgreSQLConnection.swift index ab7d10c..cd9f560 100644 --- a/Sources/SwiftKueryPostgreSQL/PostgreSQLConnection.swift +++ b/Sources/SwiftKueryPostgreSQL/PostgreSQLConnection.swift @@ -17,18 +17,27 @@ import SwiftKuery import CLibpq +import Dispatch import Foundation +enum ConnectionState { + case idle, runningQuery, fetchingResultSet +} + // MARK: PostgreSQLConnection /// An implementation of `SwiftKuery.Connection` protocol for PostgreSQL. /// Please see [PostgreSQL manual](https://www.postgresql.org/docs/8.0/static/libpq-exec.html) for details. public class PostgreSQLConnection: Connection { - private var connection: OpaquePointer? + var connection: OpaquePointer? private var connectionParameters: String = "" private var inTransaction = false + private var state: ConnectionState = .idle + private var stateLock = DispatchSemaphore(value: 1) + private weak var currentResultFetcher: PostgreSQLResultFetcher? + private var preparedStatements = Set() /// An indication whether there is a connection to the database. @@ -286,16 +295,21 @@ public class PostgreSQLConnection: Connection { } private func prepareStatement(name: String, for query: String) -> String? { + if let error = setUpForRunningQuery() { + return error + } let result = PQprepare(connection, name, query, 0, nil) let status = PQresultStatus(result) if status != PGRES_COMMAND_OK { - var errorMessage = "Failed to create prepared statement." - if let error = String(validatingUTF8: PQerrorMessage(connection)) { - errorMessage += " Error: \(error)." - } - PQclear(result) - return errorMessage + setState(.idle) + var errorMessage = "Failed to create prepared statement." + if let error = String(validatingUTF8: PQerrorMessage(connection)) { + errorMessage += " Error: \(error)." + } + PQclear(result) + return errorMessage } + setState(.idle) PQclear(result) preparedStatements.insert(name) return nil @@ -342,6 +356,11 @@ public class PostgreSQLConnection: Connection { return } + if let error = setUpForRunningQuery() { + onCompletion(.error(QueryError.connection(error))) + return + } + var parameterPointers = [UnsafeMutablePointer?]() var parameterData = [UnsafePointer?]() // At the moment we only create string parameters. Binary parameters should be added. @@ -391,6 +410,7 @@ public class PostgreSQLConnection: Connection { private func processQueryResult(query: String, onCompletion: @escaping ((QueryResult) -> ())) { guard let result = PQgetResult(connection) else { + setState(.idle) var errorMessage = "No result returned for query: \(query)." if let error = String(validatingUTF8: PQerrorMessage(connection)) { errorMessage += " Error: \(error)." @@ -403,16 +423,18 @@ public class PostgreSQLConnection: Connection { if status == PGRES_COMMAND_OK || status == PGRES_TUPLES_OK { // Since we set the single row mode, PGRES_TUPLES_OK means the result is empty, i.e. there are // no rows to return. - clearResult(result, connection: connection) + clearResult(result, connection: self) onCompletion(.successNoData) } else if status == PGRES_SINGLE_TUPLE { - let resultFetcher = PostgreSQLResultFetcher(queryResult: result, connection: connection) + let resultFetcher = PostgreSQLResultFetcher(queryResult: result, connection: self) + setState(.fetchingResultSet) + currentResultFetcher = resultFetcher onCompletion(.resultSet(ResultSet(resultFetcher))) } else { let errorMessage = String(validatingUTF8: PQresultErrorMessage(result)) ?? "Unknown" - clearResult(result, connection: connection) + clearResult(result, connection: self) onCompletion(.error(QueryError.databaseError("Query execution error:\n" + errorMessage + " For query: " + query))) } } @@ -475,6 +497,11 @@ public class PostgreSQLConnection: Connection { return } + if let error = setUpForRunningQuery() { + onCompletion(.error(QueryError.connection(error))) + return + } + let result = PQexec(connection, command) let status = PQresultStatus(result) if status != PGRES_COMMAND_OK { @@ -484,6 +511,7 @@ public class PostgreSQLConnection: Connection { } PQclear(result) + setState(.idle) onCompletion(.error(QueryError.databaseError(message))) return } @@ -493,6 +521,7 @@ public class PostgreSQLConnection: Connection { } PQclear(result) + setState(.idle) onCompletion(.successNoData) } @@ -513,4 +542,46 @@ public class PostgreSQLConnection: Connection { } return postgresQuery } + + private func lockStateLock() { + _ = stateLock.wait(timeout: DispatchTime.distantFuture) + } + + private func unlockStateLock() { + stateLock.signal() + } + + func setState(_ newState: ConnectionState) { + lockStateLock() + if state == .fetchingResultSet { + currentResultFetcher = nil + } + state = newState + unlockStateLock() + } + + func setUpForRunningQuery() -> String? { + lockStateLock() + + switch state { + case .runningQuery: + unlockStateLock() + return "The connection is in the middle of running a query" + + case .fetchingResultSet: + currentResultFetcher?.hasMoreRows = false + unlockStateLock() + clearResult(nil, connection: self) + lockStateLock() + + case .idle: + break + } + + state = .runningQuery + + unlockStateLock() + + return nil + } } diff --git a/Sources/SwiftKueryPostgreSQL/PostgreSQLResultFetcher.swift b/Sources/SwiftKueryPostgreSQL/PostgreSQLResultFetcher.swift index 118d816..6e7d3c2 100644 --- a/Sources/SwiftKueryPostgreSQL/PostgreSQLResultFetcher.swift +++ b/Sources/SwiftKueryPostgreSQL/PostgreSQLResultFetcher.swift @@ -25,10 +25,10 @@ import Foundation public class PostgreSQLResultFetcher: ResultFetcher { private let titles: [String] private var row: [Any?]? - private var connection: OpaquePointer? - private var hasMoreRows = true + private var connection: PostgreSQLConnection + var hasMoreRows = true - init(queryResult: OpaquePointer, connection: OpaquePointer?) { + init(queryResult: OpaquePointer, connection: PostgreSQLConnection) { self.connection = connection let columns = PQnfields(queryResult) @@ -53,7 +53,7 @@ public class PostgreSQLResultFetcher: ResultFetcher { return nil } - guard let queryResult = PQgetResult(connection) else { + guard let queryResult = PQgetResult(connection.connection) else { // We are not supposed to get here, because we clear the result if we get PGRES_TUPLES_OK. hasMoreRows = false return nil diff --git a/Sources/SwiftKueryPostgreSQL/Utils.swift b/Sources/SwiftKueryPostgreSQL/Utils.swift index dd011e0..bcc74f6 100644 --- a/Sources/SwiftKueryPostgreSQL/Utils.swift +++ b/Sources/SwiftKueryPostgreSQL/Utils.swift @@ -17,13 +17,16 @@ import CLibpq import Foundation -func clearResult(_ lastResult: OpaquePointer, connection: OpaquePointer?) { - PQclear(lastResult) - var result = PQgetResult(connection) +func clearResult(_ lastResult: OpaquePointer?, connection: PostgreSQLConnection) { + if let lastResult = lastResult { + PQclear(lastResult) + } + var result = PQgetResult(connection.connection) while result != nil { PQclear(result) - result = PQgetResult(connection) + result = PQgetResult(connection.connection) } + connection.setState(.idle) } diff --git a/Tests/SwiftKueryPostgreSQLTests/TestSelect.swift b/Tests/SwiftKueryPostgreSQLTests/TestSelect.swift index e884470..66cab7a 100644 --- a/Tests/SwiftKueryPostgreSQLTests/TestSelect.swift +++ b/Tests/SwiftKueryPostgreSQLTests/TestSelect.swift @@ -26,11 +26,13 @@ let tableSelect = "tableSelectLinux" let tableSelect2 = "tableSelect2Linux" let tableSelect3 = "tableSelect3Linux" let tableSelectDate = "tableSelectDateLinux" +let tableConnectionState = "tableConnectionStateLinux" #else let tableSelect = "tableSelectOSX" let tableSelect2 = "tableSelect2OSX" let tableSelect3 = "tableSelect3OSX" let tableSelectDate = "tableSelectDateOSX" +let tableConnectionState = "tableConnectionStateOSX" #endif class TestSelect: XCTestCase { @@ -40,6 +42,7 @@ class TestSelect: XCTestCase { ("testSelect", testSelect), ("testSelectDate", testSelectDate), ("testSelectFromMany", testSelectFromMany), + ("testConnectionState", testConnectionState), ] } @@ -367,4 +370,63 @@ class TestSelect: XCTestCase { }) } + + class ConnectionStateTable: Table { + let a = Column("a", Varchar.self, length: 30) + let b = Column("b", Int32.self) + + let tableName = tableConnectionState + } + + func testConnectionState() { + let t = ConnectionStateTable() + + let pool = CommonUtils.sharedInstance.getConnectionPool() + performTest(asyncTasks: { expectation in + + guard let connection = pool.getConnection() else { + XCTFail("Failed to get connection") + return + } + + cleanUp(table: t.tableName, connection: connection) { result in + + t.create(connection: connection) { result in + XCTAssertEqual(result.success, true, "CREATE TABLE failed") + XCTAssertNil(result.asError, "Error in CREATE TABLE: \(result.asError!)") + + let i = Insert(into: t, rows: [["apple", 1], ["apricot", 2], ["banana", 3], ["qiwi", -1], ["plum", -2], ["peach", -3]]) + executeQuery(query: i, connection: connection) { result, rows in + XCTAssertEqual(result.success, true, "INSERT failed") + + let s1 = Select(from: t).where(t.b > 0) + s1.execute(connection) { result1 in + let s2 = Select(from: t).where(t.b < 0) + s2.execute(connection) { result2 in + + XCTAssertEqual(result1.success, true, "SELECT 1 failed") + XCTAssertEqual(result2.success, true, "SELECT 2 failed") + + var rows: [[Any?]]? = nil + if let resultSet = result2.asResultSet { + rows = rowsAsArray(resultSet) + if let rows = rows { + for row in rows { + if let b = row[1] as? Int32 { + XCTAssertTrue(b < 0, "Bad result for SELECT") + } + else { + XCTFail("Wrong type in SELECT") + } + } + } + } + } + } + } + } + } + expectation.fulfill() + }) + } }