Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transaction): Adding withTransaction #519

Merged
merged 14 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions Sources/PostgresNIO/Pool/PostgresClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,28 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service {

return try await closure(connection)
}

/// Lease a connection for the provided `closure`'s lifetime.
/// A transation starts with call to withConnection
/// A transaction should end with a call to COMMIT or ROLLBACK
/// COMMIT is called upon successful completion and ROLLBACK is called should any steps fail
///
/// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture
/// the provided `PostgresConnection`.
/// - Returns: The closure's return value.
public func withTransaction<Result>(_ process: (PostgresConnection) async throws -> Result) async throws -> Result {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should capture calling file and line here. Then we could attach that info to the error that is thrown. We would wrap the thrown error in a PostgresTransactionError. We could also attach the Rollback error, if that happens. cc @gwynne WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me, yeah.

try await withConnection { connection in
try await connection.query("BEGIN;", logger: self.backgroundLogger)
do {
let value = try await process(connection)
try await connection.query("COMMIT;", logger: self.backgroundLogger)
return value
} catch {
try await connection.query("ROLLBACK;", logger: self.backgroundLogger)
throw error
}
}
}

/// Run a query on the Postgres server the client is connected to.
///
Expand Down
104 changes: 104 additions & 0 deletions Tests/IntegrationTests/PostgresClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,110 @@ final class PostgresClientTests: XCTestCase {
taskGroup.cancelAll()
}
}

func testTransaction() async throws {
var mlogger = Logger(label: "test")
mlogger.logLevel = .debug
let logger = mlogger
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8)
self.addTeardownBlock {
try await eventLoopGroup.shutdownGracefully()
}

let tableName = "test_client_transactions"

let clientConfig = PostgresClient.Configuration.makeTestConfiguration()
let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger)

do {
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
taskGroup.addTask {
await client.run()
}

try await client.query(
"""
CREATE TABLE IF NOT EXISTS "\(unescaped: tableName)" (
id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
uuid UUID NOT NULL
);
""",
logger: logger
)

let iterations = 1000

for _ in 0..<iterations {
taskGroup.addTask {
let _ = try await client.withTransaction { transaction in
try await transaction.query(
"""
INSERT INTO "\(unescaped: tableName)" (uuid) VALUES (\(UUID()));
""",
logger: logger
)
}
}
}

for _ in 0..<iterations {
_ = await taskGroup.nextResult()!
}

let rows = try await client.query(#"SELECT COUNT(1)::INT AS table_size FROM "\#(unescaped: tableName)";"#, logger: logger).decode(Int.self)
for try await (count) in rows {
XCTAssertEqual(count, iterations)
}

/// Test roll back
taskGroup.addTask {

do {
let _ = try await client.withTransaction { transaction in
/// insert valid data
try await transaction.query(
"""
INSERT INTO "\(unescaped: tableName)" (uuid) VALUES (\(UUID()));
""",
logger: logger
)

/// insert invalid data
try await transaction.query(
"""
INSERT INTO "\(unescaped: tableName)" (uuid) VALUES (\(iterations));
""",
logger: logger
)
}
} catch {
XCTAssertNotNil(error)
guard let error = error as? PSQLError else { return XCTFail("Unexpected error type") }

XCTAssertEqual(error.code, .server)
XCTAssertEqual(error.serverInfo?[.severity], "ERROR")
}
}

let row = try await client.query(#"SELECT COUNT(1)::INT AS table_size FROM "\#(unescaped: tableName)";"#, logger: logger).decode(Int.self)

for try await (count) in row {
XCTAssertEqual(count, iterations)
}

try await client.query(
"""
DROP TABLE "\(unescaped: tableName)";
""",
logger: logger
)

taskGroup.cancelAll()
}
} catch {
XCTFail("Unexpected error: \(String(reflecting: error))")
}
}

func testApplicationNameIsForwardedCorrectly() async throws {
var mlogger = Logger(label: "test")
Expand Down
Loading