Skip to content

Commit 8a66dc5

Browse files
authored
Close client connection when remote closes connection + testing (#686) (#687)
1 parent a5c10ab commit 8a66dc5

File tree

2 files changed

+147
-2
lines changed

2 files changed

+147
-2
lines changed

src/main/java/net/schmizz/sshj/common/StreamCopier.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,14 @@ public long copy()
145145
final double sizeKiB = count / 1024.0;
146146
log.debug(String.format("%1$,.1f KiB transferred in %2$,.1f seconds (%3$,.2f KiB/s)", sizeKiB, timeSeconds, (sizeKiB / timeSeconds)));
147147

148-
if (length != -1 && read == -1)
149-
throw new IOException("Encountered EOF, could not transfer " + length + " bytes");
148+
// Did we encounter EOF?
149+
if (read == -1) {
150+
// If InputStream was closed we should also close OutputStream
151+
out.close();
152+
153+
if (length != -1)
154+
throw new IOException("Encountered EOF, could not transfer " + length + " bytes");
155+
}
150156

151157
return count;
152158
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Copyright (C)2009 - SSHJ Contributors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.hierynomus.sshj.connection.channel.forwarded;
17+
18+
import com.hierynomus.sshj.test.HttpServer;
19+
import com.hierynomus.sshj.test.SshFixture;
20+
import com.hierynomus.sshj.test.util.FileUtil;
21+
import net.schmizz.sshj.SSHClient;
22+
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder;
23+
import net.schmizz.sshj.connection.channel.direct.Parameters;
24+
import org.apache.http.HttpResponse;
25+
import org.apache.http.client.HttpClient;
26+
import org.apache.http.client.methods.HttpGet;
27+
import org.apache.http.impl.client.HttpClientBuilder;
28+
import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
29+
import org.junit.Assert;
30+
import org.junit.Before;
31+
import org.junit.Rule;
32+
import org.junit.Test;
33+
import org.slf4j.Logger;
34+
import org.slf4j.LoggerFactory;
35+
36+
import java.io.*;
37+
import java.net.InetSocketAddress;
38+
import java.net.ServerSocket;
39+
import java.net.Socket;
40+
41+
import static org.hamcrest.CoreMatchers.equalTo;
42+
import static org.junit.Assert.assertThat;
43+
44+
public class LocalPortForwarderTest {
45+
private static final Logger log = LoggerFactory.getLogger(LocalPortForwarderTest.class);
46+
47+
@Rule
48+
public SshFixture fixture = new SshFixture();
49+
50+
@Rule
51+
public HttpServer httpServer = new HttpServer();
52+
53+
@Before
54+
public void setUp() throws IOException {
55+
fixture.getServer().setForwardingFilter(new AcceptAllForwardingFilter());
56+
File file = httpServer.getDocRoot().newFile("index.html");
57+
FileUtil.writeToFile(file, "<html><head/><body><h1>Hi!</h1></body></html>");
58+
}
59+
60+
@Test
61+
public void shouldHaveWorkingHttpServer() throws IOException {
62+
// Just to check that we have a working http server...
63+
assertThat(httpGet("127.0.0.1", 8080), equalTo(200));
64+
}
65+
66+
@Test
67+
public void shouldHaveHttpServerThatClosesConnectionAfterResponse() throws IOException {
68+
// Just to check that the test server does close connections before we try through the forwarder...
69+
httpGetAndAssertConnectionClosedByServer(8080);
70+
}
71+
72+
@Test(timeout = 10_000)
73+
public void shouldCloseConnectionWhenRemoteServerClosesConnection() throws IOException {
74+
SSHClient sshClient = getFixtureClient();
75+
76+
ServerSocket serverSocket = new ServerSocket();
77+
serverSocket.setReuseAddress(true);
78+
serverSocket.bind(new InetSocketAddress("0.0.0.0", 12345));
79+
LocalPortForwarder localPortForwarder = sshClient.newLocalPortForwarder(new Parameters("0.0.0.0", 12345, "localhost", 8080), serverSocket);
80+
new Thread(() -> {
81+
try {
82+
localPortForwarder.listen();
83+
} catch (IOException e) {
84+
throw new RuntimeException(e);
85+
}
86+
}, "local port listener").start();
87+
88+
// Test once to prove that the local HTTP connection is closed when the remote HTTP connection is closed.
89+
httpGetAndAssertConnectionClosedByServer(12345);
90+
91+
// Test again to prove that the tunnel is still open, even after HTTP connection was closed.
92+
httpGetAndAssertConnectionClosedByServer(12345);
93+
}
94+
95+
public static void httpGetAndAssertConnectionClosedByServer(int port) throws IOException {
96+
System.out.println("HTTP GET to port: " + port);
97+
try (Socket socket = new Socket("localhost", port)) {
98+
// Send a basic HTTP GET
99+
// It returns 400 Bad Request because it's missing a bunch of info, but the HTTP response doesn't matter, we just want to test the connection closing.
100+
OutputStream outputStream = socket.getOutputStream();
101+
PrintWriter writer = new PrintWriter(outputStream);
102+
writer.println("GET / HTTP/1.1");
103+
writer.println("");
104+
writer.flush();
105+
106+
// Read the HTTP response
107+
InputStream inputStream = socket.getInputStream();
108+
InputStreamReader reader = new InputStreamReader(inputStream);
109+
int buf = -2;
110+
while (true) {
111+
buf = reader.read();
112+
System.out.print((char)buf);
113+
if (buf == -1) {
114+
break;
115+
}
116+
}
117+
118+
// Attempt to read more. If the server has closed the connection this will return -1
119+
int read = inputStream.read();
120+
121+
// Assert input stream was closed by server.
122+
Assert.assertEquals(-1, read);
123+
}
124+
}
125+
126+
private int httpGet(String server, int port) throws IOException {
127+
HttpClient client = HttpClientBuilder.create().build();
128+
String urlString = "http://" + server + ":" + port;
129+
log.info("Trying: GET " + urlString);
130+
HttpResponse execute = client.execute(new HttpGet(urlString));
131+
return execute.getStatusLine().getStatusCode();
132+
}
133+
134+
private SSHClient getFixtureClient() throws IOException {
135+
SSHClient sshClient = fixture.setupConnectedDefaultClient();
136+
sshClient.authPassword("jeroen", "jeroen");
137+
return sshClient;
138+
}
139+
}

0 commit comments

Comments
 (0)