Skip to content

Commit

Permalink
Permit reuse of connection within result callback (#45)
Browse files Browse the repository at this point in the history
- terminates the first query fetcher when a second query is made using the same connection from within a result callback, to allow the inner query fetcher to function properly
- resolves #23 and #24
  • Loading branch information
dsperling authored and djones6 committed Mar 14, 2018
1 parent 3be53cc commit 1d905f6
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 18 deletions.
91 changes: 81 additions & 10 deletions Sources/SwiftKueryPostgreSQL/PostgreSQLConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,27 @@
import SwiftKuery
import CLibpq

import Dispatch
import Foundation

enum ConnectionState {
case idle, runningQuery, fetchingResultSet
}

// MARK: PostgreSQLConnection

/// An implementation of `SwiftKuery.Connection` protocol for PostgreSQL.
/// Please see [PostgreSQL manual](https://www.postgresql.org/docs/8.0/static/libpq-exec.html) for details.
public class PostgreSQLConnection: Connection {

private var connection: OpaquePointer?
var connection: OpaquePointer?
private var connectionParameters: String = ""
private var inTransaction = false

private var state: ConnectionState = .idle
private var stateLock = DispatchSemaphore(value: 1)
private weak var currentResultFetcher: PostgreSQLResultFetcher?

private var preparedStatements = Set<String>()

/// An indication whether there is a connection to the database.
Expand Down Expand Up @@ -286,16 +295,21 @@ public class PostgreSQLConnection: Connection {
}

private func prepareStatement(name: String, for query: String) -> String? {
if let error = setUpForRunningQuery() {
return error
}
let result = PQprepare(connection, name, query, 0, nil)
let status = PQresultStatus(result)
if status != PGRES_COMMAND_OK {
var errorMessage = "Failed to create prepared statement."
if let error = String(validatingUTF8: PQerrorMessage(connection)) {
errorMessage += " Error: \(error)."
}
PQclear(result)
return errorMessage
setState(.idle)
var errorMessage = "Failed to create prepared statement."
if let error = String(validatingUTF8: PQerrorMessage(connection)) {
errorMessage += " Error: \(error)."
}
PQclear(result)
return errorMessage
}
setState(.idle)
PQclear(result)
preparedStatements.insert(name)
return nil
Expand Down Expand Up @@ -342,6 +356,11 @@ public class PostgreSQLConnection: Connection {
return
}

if let error = setUpForRunningQuery() {
onCompletion(.error(QueryError.connection(error)))
return
}

var parameterPointers = [UnsafeMutablePointer<Int8>?]()
var parameterData = [UnsafePointer<Int8>?]()
// At the moment we only create string parameters. Binary parameters should be added.
Expand Down Expand Up @@ -391,6 +410,7 @@ public class PostgreSQLConnection: Connection {

private func processQueryResult(query: String, onCompletion: @escaping ((QueryResult) -> ())) {
guard let result = PQgetResult(connection) else {
setState(.idle)
var errorMessage = "No result returned for query: \(query)."
if let error = String(validatingUTF8: PQerrorMessage(connection)) {
errorMessage += " Error: \(error)."
Expand All @@ -403,16 +423,18 @@ public class PostgreSQLConnection: Connection {
if status == PGRES_COMMAND_OK || status == PGRES_TUPLES_OK {
// Since we set the single row mode, PGRES_TUPLES_OK means the result is empty, i.e. there are
// no rows to return.
clearResult(result, connection: connection)
clearResult(result, connection: self)
onCompletion(.successNoData)
}
else if status == PGRES_SINGLE_TUPLE {
let resultFetcher = PostgreSQLResultFetcher(queryResult: result, connection: connection)
let resultFetcher = PostgreSQLResultFetcher(queryResult: result, connection: self)
setState(.fetchingResultSet)
currentResultFetcher = resultFetcher
onCompletion(.resultSet(ResultSet(resultFetcher)))
}
else {
let errorMessage = String(validatingUTF8: PQresultErrorMessage(result)) ?? "Unknown"
clearResult(result, connection: connection)
clearResult(result, connection: self)
onCompletion(.error(QueryError.databaseError("Query execution error:\n" + errorMessage + " For query: " + query)))
}
}
Expand Down Expand Up @@ -475,6 +497,11 @@ public class PostgreSQLConnection: Connection {
return
}

if let error = setUpForRunningQuery() {
onCompletion(.error(QueryError.connection(error)))
return
}

let result = PQexec(connection, command)
let status = PQresultStatus(result)
if status != PGRES_COMMAND_OK {
Expand All @@ -484,6 +511,7 @@ public class PostgreSQLConnection: Connection {
}

PQclear(result)
setState(.idle)
onCompletion(.error(QueryError.databaseError(message)))
return
}
Expand All @@ -493,6 +521,7 @@ public class PostgreSQLConnection: Connection {
}

PQclear(result)
setState(.idle)
onCompletion(.successNoData)
}

Expand All @@ -513,4 +542,46 @@ public class PostgreSQLConnection: Connection {
}
return postgresQuery
}

private func lockStateLock() {
_ = stateLock.wait(timeout: DispatchTime.distantFuture)
}

private func unlockStateLock() {
stateLock.signal()
}

func setState(_ newState: ConnectionState) {
lockStateLock()
if state == .fetchingResultSet {
currentResultFetcher = nil
}
state = newState
unlockStateLock()
}

func setUpForRunningQuery() -> String? {
lockStateLock()

switch state {
case .runningQuery:
unlockStateLock()
return "The connection is in the middle of running a query"

case .fetchingResultSet:
currentResultFetcher?.hasMoreRows = false
unlockStateLock()
clearResult(nil, connection: self)
lockStateLock()

case .idle:
break
}

state = .runningQuery

unlockStateLock()

return nil
}
}
8 changes: 4 additions & 4 deletions Sources/SwiftKueryPostgreSQL/PostgreSQLResultFetcher.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ import Foundation
public class PostgreSQLResultFetcher: ResultFetcher {
private let titles: [String]
private var row: [Any?]?
private var connection: OpaquePointer?
private var hasMoreRows = true
private var connection: PostgreSQLConnection
var hasMoreRows = true

init(queryResult: OpaquePointer, connection: OpaquePointer?) {
init(queryResult: OpaquePointer, connection: PostgreSQLConnection) {
self.connection = connection

let columns = PQnfields(queryResult)
Expand All @@ -53,7 +53,7 @@ public class PostgreSQLResultFetcher: ResultFetcher {
return nil
}

guard let queryResult = PQgetResult(connection) else {
guard let queryResult = PQgetResult(connection.connection) else {
// We are not supposed to get here, because we clear the result if we get PGRES_TUPLES_OK.
hasMoreRows = false
return nil
Expand Down
11 changes: 7 additions & 4 deletions Sources/SwiftKueryPostgreSQL/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
import CLibpq
import Foundation

func clearResult(_ lastResult: OpaquePointer, connection: OpaquePointer?) {
PQclear(lastResult)
var result = PQgetResult(connection)
func clearResult(_ lastResult: OpaquePointer?, connection: PostgreSQLConnection) {
if let lastResult = lastResult {
PQclear(lastResult)
}
var result = PQgetResult(connection.connection)
while result != nil {
PQclear(result)
result = PQgetResult(connection)
result = PQgetResult(connection.connection)
}
connection.setState(.idle)
}


Expand Down
62 changes: 62 additions & 0 deletions Tests/SwiftKueryPostgreSQLTests/TestSelect.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ let tableSelect = "tableSelectLinux"
let tableSelect2 = "tableSelect2Linux"
let tableSelect3 = "tableSelect3Linux"
let tableSelectDate = "tableSelectDateLinux"
let tableConnectionState = "tableConnectionStateLinux"
#else
let tableSelect = "tableSelectOSX"
let tableSelect2 = "tableSelect2OSX"
let tableSelect3 = "tableSelect3OSX"
let tableSelectDate = "tableSelectDateOSX"
let tableConnectionState = "tableConnectionStateOSX"
#endif

class TestSelect: XCTestCase {
Expand All @@ -40,6 +42,7 @@ class TestSelect: XCTestCase {
("testSelect", testSelect),
("testSelectDate", testSelectDate),
("testSelectFromMany", testSelectFromMany),
("testConnectionState", testConnectionState),
]
}

Expand Down Expand Up @@ -367,4 +370,63 @@ class TestSelect: XCTestCase {
})
}


class ConnectionStateTable: Table {
let a = Column("a", Varchar.self, length: 30)
let b = Column("b", Int32.self)

let tableName = tableConnectionState
}

func testConnectionState() {
let t = ConnectionStateTable()

let pool = CommonUtils.sharedInstance.getConnectionPool()
performTest(asyncTasks: { expectation in

guard let connection = pool.getConnection() else {
XCTFail("Failed to get connection")
return
}

cleanUp(table: t.tableName, connection: connection) { result in

t.create(connection: connection) { result in
XCTAssertEqual(result.success, true, "CREATE TABLE failed")
XCTAssertNil(result.asError, "Error in CREATE TABLE: \(result.asError!)")

let i = Insert(into: t, rows: [["apple", 1], ["apricot", 2], ["banana", 3], ["qiwi", -1], ["plum", -2], ["peach", -3]])
executeQuery(query: i, connection: connection) { result, rows in
XCTAssertEqual(result.success, true, "INSERT failed")

let s1 = Select(from: t).where(t.b > 0)
s1.execute(connection) { result1 in
let s2 = Select(from: t).where(t.b < 0)
s2.execute(connection) { result2 in

XCTAssertEqual(result1.success, true, "SELECT 1 failed")
XCTAssertEqual(result2.success, true, "SELECT 2 failed")

var rows: [[Any?]]? = nil
if let resultSet = result2.asResultSet {
rows = rowsAsArray(resultSet)
if let rows = rows {
for row in rows {
if let b = row[1] as? Int32 {
XCTAssertTrue(b < 0, "Bad result for SELECT")
}
else {
XCTFail("Wrong type in SELECT")
}
}
}
}
}
}
}
}
}
expectation.fulfill()
})
}
}

0 comments on commit 1d905f6

Please sign in to comment.