diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index ad8a4bf1..e9e947ef 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -307,6 +307,28 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return try await closure(connection) } + + /// Lease a connection for the provided `closure`'s lifetime. + /// A transation starts with call to withConnection + /// A transaction should end with a call to COMMIT or ROLLBACK + /// COMMIT is called upon successful completion and ROLLBACK is called should any steps fail + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + public func withTransaction(_ process: (PostgresConnection) async throws -> Result) async throws -> Result { + try await withConnection { connection in + try await connection.query("BEGIN;", logger: self.backgroundLogger) + do { + let value = try await process(connection) + try await connection.query("COMMIT;", logger: self.backgroundLogger) + return value + } catch { + try await connection.query("ROLLBACK;", logger: self.backgroundLogger) + throw error + } + } + } /// Run a query on the Postgres server the client is connected to. /// diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index 579c92cd..167ba298 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -42,6 +42,110 @@ final class PostgresClientTests: XCTestCase { taskGroup.cancelAll() } } + + func testTransaction() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let tableName = "test_client_transactions" + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + do { + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + try await client.query( + """ + CREATE TABLE IF NOT EXISTS "\(unescaped: tableName)" ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + uuid UUID NOT NULL + ); + """, + logger: logger + ) + + let iterations = 1000 + + for _ in 0..