17
17
package io .r2dbc .pool ;
18
18
19
19
import io .r2dbc .spi .Connection ;
20
+ import io .r2dbc .spi .TransactionDefinition ;
20
21
import io .r2dbc .spi .ValidationDepth ;
21
22
import org .junit .jupiter .api .BeforeEach ;
22
23
import org .junit .jupiter .api .Test ;
28
29
import java .util .concurrent .atomic .AtomicInteger ;
29
30
30
31
import static org .assertj .core .api .Assertions .assertThat ;
32
+ import static org .mockito .ArgumentMatchers .any ;
31
33
import static org .mockito .Mockito .mock ;
32
34
import static org .mockito .Mockito .never ;
33
35
import static org .mockito .Mockito .reset ;
@@ -52,6 +54,7 @@ void setUp() {
52
54
when (pooledRefMock .poolable ()).thenReturn (connectionMock );
53
55
when (pooledRefMock .release ()).thenReturn (Mono .empty ());
54
56
when (connectionMock .beginTransaction ()).thenReturn (Mono .empty ());
57
+ when (connectionMock .beginTransaction (any ())).thenReturn (Mono .empty ());
55
58
when (connectionMock .close ()).thenReturn (Mono .empty ());
56
59
when (connectionMock .validate (ValidationDepth .LOCAL )).thenReturn (Mono .empty ());
57
60
}
@@ -71,6 +74,21 @@ void shouldRollbackUnfinishedTransaction() {
71
74
assertThat (wasCalled ).isTrue ();
72
75
}
73
76
77
+ @ Test
78
+ void shouldRollbackUnfinishedExtendedTransaction () {
79
+
80
+ AtomicBoolean wasCalled = new AtomicBoolean ();
81
+ when (connectionMock .rollbackTransaction ()).thenReturn (Mono .<Void >empty ().doOnSuccess (o -> wasCalled .set (true )));
82
+
83
+ PooledConnection connection = new PooledConnection (pooledRefMock );
84
+ connection .beginTransaction (mock (TransactionDefinition .class )).as (StepVerifier ::create ).verifyComplete ();
85
+
86
+ connection .close ().as (StepVerifier ::create ).verifyComplete ();
87
+
88
+ verify (connectionMock ).rollbackTransaction ();
89
+ assertThat (wasCalled ).isTrue ();
90
+ }
91
+
74
92
@ Test
75
93
void shouldPristineTransactionLeavesTransactionalStateAsIs () {
76
94
0 commit comments