Skip to content

Commit b402400

Browse files
pvlugtermp911de
authored andcommitted
Handle early disconnects before SSL handshake.
[resolves #595][#596]
1 parent 5688111 commit b402400

File tree

4 files changed

+140
-44
lines changed

4 files changed

+140
-44
lines changed

src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java

+19-20
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ private ReactorNettyClient(Connection connection, ConnectionSettings settings) {
144144
Assert.requireNonNull(connection, "Connection must not be null");
145145
this.settings = Assert.requireNonNull(settings, "ConnectionSettings must not be null");
146146

147-
connection.addHandlerFirst(new EnsureSubscribersCompleteChannelHandler(this.requestSink));
147+
connection.addHandlerLast(new EnsureSubscribersCompleteChannelHandler(this.requestSink));
148148
connection.addHandlerLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0));
149149
this.connection = connection;
150150
this.byteBufAllocator = connection.outbound().alloc();
@@ -392,43 +392,42 @@ public static Mono<ReactorNettyClient> connect(SocketAddress socketAddress, Conn
392392
tcpClient = tcpClient.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, settings.getConnectTimeoutMs());
393393
}
394394

395-
return tcpClient.connect().flatMap(it -> {
396-
397-
ChannelPipeline pipeline = it.channel().pipeline();
395+
return tcpClient.doOnChannelInit((observer, channel, remoteAddress) -> {
396+
ChannelPipeline pipeline = channel.pipeline();
398397

399398
InternalLogger logger = InternalLoggerFactory.getInstance(ReactorNettyClient.class);
400399
if (logger.isTraceEnabled()) {
401400
pipeline.addFirst(LoggingHandler.class.getSimpleName(),
402401
new LoggingHandler(ReactorNettyClient.class, LogLevel.TRACE));
403402
}
404403

405-
return registerSslHandler(settings.getSslConfig(), it).thenReturn(new ReactorNettyClient(it, settings));
406-
});
404+
registerSslHandler(settings.getSslConfig(), channel);
405+
}).connect().flatMap(it ->
406+
getSslHandshake(it.channel()).thenReturn(new ReactorNettyClient(it, settings))
407+
);
407408
}
408409

409-
private static Mono<? extends Void> registerSslHandler(SSLConfig sslConfig, Connection it) {
410-
410+
private static void registerSslHandler(SSLConfig sslConfig, Channel channel) {
411411
try {
412412
if (sslConfig.getSslMode().startSsl()) {
413413

414-
return Mono.defer(() -> {
415-
AbstractPostgresSSLHandlerAdapter sslAdapter;
416-
if (sslConfig.getSslMode() == SSLMode.TUNNEL) {
417-
sslAdapter = new SSLTunnelHandlerAdapter(it.outbound().alloc(), sslConfig);
418-
} else {
419-
sslAdapter = new SSLSessionHandlerAdapter(it.outbound().alloc(), sslConfig);
420-
}
421-
422-
it.addHandlerFirst(sslAdapter);
423-
return sslAdapter.getHandshake();
414+
AbstractPostgresSSLHandlerAdapter sslAdapter;
415+
if (sslConfig.getSslMode() == SSLMode.TUNNEL) {
416+
sslAdapter = new SSLTunnelHandlerAdapter(channel.alloc(), sslConfig);
417+
} else {
418+
sslAdapter = new SSLSessionHandlerAdapter(channel.alloc(), sslConfig);
419+
}
424420

425-
}).subscribeOn(Schedulers.boundedElastic());
421+
channel.pipeline().addFirst(sslAdapter);
426422
}
427423
} catch (Throwable e) {
428424
throw new RuntimeException(e);
429425
}
426+
}
430427

431-
return Mono.empty();
428+
private static Mono<Void> getSslHandshake(Channel channel) {
429+
AbstractPostgresSSLHandlerAdapter sslAdapter = channel.pipeline().get(AbstractPostgresSSLHandlerAdapter.class);
430+
return (sslAdapter != null) ? sslAdapter.getHandshake() : Mono.empty();
432431
}
433432

434433
@Override

src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java

+38-21
Original file line numberDiff line numberDiff line change
@@ -33,35 +33,54 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {
3333

3434
private final SSLConfig sslConfig;
3535

36+
private boolean negotiating = true;
37+
3638
SSLSessionHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) {
3739
super(alloc, sslConfig);
3840
this.alloc = alloc;
3941
this.sslConfig = sslConfig;
4042
}
4143

4244
@Override
43-
public void handlerAdded(ChannelHandlerContext ctx) {
44-
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
45+
public void channelActive(ChannelHandlerContext ctx) throws Exception {
46+
if (negotiating) {
47+
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
48+
}
49+
super.channelActive(ctx);
50+
}
51+
52+
@Override
53+
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
54+
if (negotiating) {
55+
// If we receive channel inactive before negotiated, then the inbound has closed early.
56+
PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation");
57+
completeHandshakeExceptionally(e);
58+
}
59+
super.channelInactive(ctx);
4560
}
4661

4762
@Override
48-
public void channelRead(ChannelHandlerContext ctx, Object msg) {
49-
ByteBuf buf = (ByteBuf) msg;
50-
char response = (char) buf.readByte();
51-
try {
52-
switch (response) {
53-
case 'S':
54-
processSslEnabled(ctx, buf);
55-
break;
56-
case 'N':
57-
processSslDisabled();
58-
break;
59-
default:
60-
buf.release();
61-
throw new IllegalStateException("Unknown SSLResponse from server: '" + response + "'");
63+
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
64+
if (negotiating) {
65+
ByteBuf buf = (ByteBuf) msg;
66+
char response = (char) buf.readByte();
67+
try {
68+
switch (response) {
69+
case 'S':
70+
processSslEnabled(ctx, buf);
71+
break;
72+
case 'N':
73+
processSslDisabled();
74+
break;
75+
default:
76+
throw new IllegalStateException("Unknown SSLResponse from server: '" + response + "'");
77+
}
78+
} finally {
79+
buf.release();
80+
negotiating = false;
6281
}
63-
} finally {
64-
buf.release();
82+
} else {
83+
super.channelRead(ctx, msg);
6584
}
6685
}
6786

@@ -82,9 +101,7 @@ private void processSslEnabled(ChannelHandlerContext ctx, ByteBuf msg) {
82101
completeHandshakeExceptionally(e);
83102
return;
84103
}
85-
ctx.channel().pipeline()
86-
.addFirst(this.getSslHandler())
87-
.remove(this);
104+
ctx.channel().pipeline().addFirst(this.getSslHandler());
88105
ctx.fireChannelRead(msg.retain());
89106
}
90107

src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ public void handlerAdded(ChannelHandlerContext ctx) {
4040
completeHandshakeExceptionally(e);
4141
return;
4242
}
43-
ctx.channel().pipeline()
44-
.addFirst(this.getSslHandler())
45-
.remove(this);
43+
ctx.channel().pipeline().addFirst(this.getSslHandler());
4644
}
4745

4846
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright 2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.r2dbc.postgresql.client;
18+
19+
import io.r2dbc.postgresql.PostgresqlConnectionConfiguration;
20+
import io.r2dbc.postgresql.PostgresqlConnectionFactory;
21+
import io.r2dbc.postgresql.api.PostgresqlException;
22+
import org.junit.jupiter.api.Test;
23+
import reactor.netty.DisposableChannel;
24+
import reactor.netty.DisposableServer;
25+
import reactor.netty.tcp.TcpServer;
26+
import reactor.test.StepVerifier;
27+
28+
import java.nio.channels.ClosedChannelException;
29+
import java.util.function.Consumer;
30+
31+
import static org.assertj.core.api.Assertions.assertThat;
32+
33+
public class DowntimeIntegrationTests {
34+
35+
// Simulate server downtime, where connections are accepted and then closed immediately
36+
static DisposableServer newServer() {
37+
return TcpServer.create()
38+
.doOnConnection(DisposableChannel::dispose)
39+
.bindNow();
40+
}
41+
42+
static PostgresqlConnectionFactory newConnectionFactory(DisposableServer server, SSLMode sslMode) {
43+
return new PostgresqlConnectionFactory(
44+
PostgresqlConnectionConfiguration.builder()
45+
.host(server.host())
46+
.port(server.port())
47+
.username("test")
48+
.sslMode(sslMode)
49+
.build());
50+
}
51+
52+
static void verifyError(SSLMode sslMode, Consumer<Throwable> assertions) {
53+
DisposableServer server = newServer();
54+
PostgresqlConnectionFactory connectionFactory = newConnectionFactory(server, sslMode);
55+
connectionFactory.create().as(StepVerifier::create).verifyErrorSatisfies(assertions);
56+
server.disposeNow();
57+
}
58+
59+
@Test
60+
void failSslHandshakeIfInboundClosed() {
61+
verifyError(SSLMode.REQUIRE, error ->
62+
assertThat(error)
63+
.isInstanceOf(AbstractPostgresSSLHandlerAdapter.PostgresqlSslException.class)
64+
.hasMessage("Connection closed during SSL negotiation"));
65+
}
66+
67+
@Test
68+
void failSslTunnelIfInboundClosed() {
69+
verifyError(SSLMode.TUNNEL, error -> {
70+
assertThat(error)
71+
.isInstanceOf(PostgresqlException.class)
72+
.cause()
73+
.isInstanceOf(ClosedChannelException.class);
74+
75+
assertThat(error.getCause().getSuppressed().length).isOne();
76+
77+
assertThat(error.getCause().getSuppressed()[0])
78+
.hasMessage("Connection closed while SSL/TLS handshake was in progress");
79+
});
80+
}
81+
82+
}

0 commit comments

Comments
 (0)