Skip to content

Commit 47df71c

Browse files
committed
Implemented diffie-hellman-group-exchange Kex methods (Fixes #167)
1 parent e24ed6e commit 47df71c

File tree

10 files changed

+300
-20
lines changed

10 files changed

+300
-20
lines changed

src/main/java/net/schmizz/sshj/DefaultConfig.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import net.schmizz.sshj.transport.compression.NoneCompression;
3434
import net.schmizz.sshj.transport.kex.DHG1;
3535
import net.schmizz.sshj.transport.kex.DHG14;
36+
import net.schmizz.sshj.transport.kex.DHGexSHA1;
37+
import net.schmizz.sshj.transport.kex.DHGexSHA256;
3638
import net.schmizz.sshj.transport.mac.HMACMD5;
3739
import net.schmizz.sshj.transport.mac.HMACMD596;
3840
import net.schmizz.sshj.transport.mac.HMACSHA1;
@@ -98,9 +100,9 @@ public DefaultConfig() {
98100

99101
protected void initKeyExchangeFactories(boolean bouncyCastleRegistered) {
100102
if (bouncyCastleRegistered)
101-
setKeyExchangeFactories(new DHG14.Factory(), new DHG1.Factory());
103+
setKeyExchangeFactories(new DHG14.Factory(), new DHG1.Factory(), new DHGexSHA1.Factory(), new DHGexSHA256.Factory());
102104
else
103-
setKeyExchangeFactories(new DHG1.Factory());
105+
setKeyExchangeFactories(new DHG1.Factory(), new DHGexSHA1.Factory());
104106
}
105107

106108
protected void initRandomFactory(boolean bouncyCastleRegistered) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package net.schmizz.sshj.transport.digest;
2+
3+
/** SHA256 Digest. */
4+
public class SHA256 extends BaseDigest {
5+
6+
/** Named factory for SHA256 digest */
7+
public static class Factory
8+
implements net.schmizz.sshj.common.Factory.Named<Digest> {
9+
10+
@Override
11+
public Digest create() {
12+
return new SHA256();
13+
}
14+
15+
@Override
16+
public String getName() {
17+
return "sha256";
18+
}
19+
}
20+
21+
/** Create a new instance of a SHA256 digest */
22+
public SHA256() {
23+
super("SHA-256", 32);
24+
}
25+
26+
}

src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java

+3-18
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,14 @@
3838
* Base class for DHG key exchange algorithms. Implementations will only have to configure the required data on the
3939
* {@link DH} class in the
4040
*/
41-
public abstract class AbstractDHG
41+
public abstract class AbstractDHG extends KeyExchangeBase
4242
implements KeyExchange {
4343

4444
private final Logger log = LoggerFactory.getLogger(getClass());
4545

46-
private Transport trans;
47-
4846
private final Digest sha1 = new SHA1();
4947
private final DH dh = new DH();
5048

51-
private String V_S;
52-
private String V_C;
53-
private byte[] I_S;
54-
private byte[] I_C;
55-
5649
private byte[] H;
5750
private PublicKey hostKey;
5851

@@ -79,11 +72,7 @@ public PublicKey getHostKey() {
7972
@Override
8073
public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C)
8174
throws GeneralSecurityException, TransportException {
82-
this.trans = trans;
83-
this.V_S = V_S;
84-
this.V_C = V_C;
85-
this.I_S = Arrays.copyOf(I_S, I_S.length);
86-
this.I_C = Arrays.copyOf(I_C, I_C.length);
75+
super.init(trans, V_S, V_C, I_S, I_C);
8776
sha1.init();
8877
initDH(dh);
8978

@@ -112,11 +101,7 @@ public boolean next(Message msg, SSHPacket packet)
112101

113102
dh.computeK(f);
114103

115-
final Buffer.PlainBuffer buf = new Buffer.PlainBuffer()
116-
.putString(V_C)
117-
.putString(V_S)
118-
.putString(I_C)
119-
.putString(I_S)
104+
final Buffer.PlainBuffer buf = initializedBuffer()
120105
.putString(K_S)
121106
.putMPInt(dh.getE())
122107
.putMPInt(f)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package net.schmizz.sshj.transport.kex;
2+
3+
import net.schmizz.sshj.common.*;
4+
import net.schmizz.sshj.signature.Signature;
5+
import net.schmizz.sshj.transport.Transport;
6+
import net.schmizz.sshj.transport.TransportException;
7+
import net.schmizz.sshj.transport.digest.Digest;
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
10+
11+
import java.math.BigInteger;
12+
import java.security.GeneralSecurityException;
13+
import java.security.PublicKey;
14+
import java.util.Arrays;
15+
16+
public abstract class AbstractDHGex extends KeyExchangeBase {
17+
private final Logger log = LoggerFactory.getLogger(getClass());
18+
19+
private Digest digest;
20+
21+
private int minBits = 1024;
22+
private int maxBits = 8192;
23+
private int preferredBits = 2048;
24+
25+
private DH dh;
26+
private PublicKey hostKey;
27+
private byte[] H;
28+
29+
public AbstractDHGex(Digest digest) {
30+
this.digest = digest;
31+
}
32+
33+
@Override
34+
public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException {
35+
super.init(trans, V_S, V_C, I_S, I_C);
36+
dh = new DH();
37+
digest.init();
38+
39+
log.debug("Sending {}", Message.KEX_DH_GEX_REQUEST);
40+
trans.write(new SSHPacket(Message.KEX_DH_GEX_REQUEST).putUInt32(minBits).putUInt32(preferredBits).putUInt32(maxBits));
41+
}
42+
43+
@Override
44+
public byte[] getH() {
45+
return Arrays.copyOf(H, H.length);
46+
}
47+
48+
@Override
49+
public BigInteger getK() {
50+
return dh.getK();
51+
}
52+
53+
@Override
54+
public Digest getHash() {
55+
return digest;
56+
}
57+
58+
@Override
59+
public PublicKey getHostKey() {
60+
return hostKey;
61+
}
62+
63+
@Override
64+
public boolean next(Message msg, SSHPacket buffer) throws GeneralSecurityException, TransportException {
65+
log.debug("Got message {}", msg);
66+
try {
67+
switch (msg) {
68+
case KEXDH_31:
69+
return parseGexGroup(buffer);
70+
case KEX_DH_GEX_REPLY:
71+
return parseGexReply(buffer);
72+
}
73+
} catch (Buffer.BufferException be) {
74+
throw new TransportException(be);
75+
}
76+
throw new TransportException("Unexpected message " + msg);
77+
}
78+
79+
private boolean parseGexReply(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException {
80+
byte[] K_S = buffer.readBytes();
81+
BigInteger f = buffer.readMPInt();
82+
byte[] sig = buffer.readBytes();
83+
hostKey = new Buffer.PlainBuffer(K_S).readPublicKey();
84+
85+
dh.computeK(f);
86+
BigInteger k = dh.getK();
87+
88+
final Buffer.PlainBuffer buf = initializedBuffer()
89+
.putString(K_S)
90+
.putUInt32(minBits)
91+
.putUInt32(preferredBits)
92+
.putUInt32(maxBits)
93+
.putMPInt(dh.getP())
94+
.putMPInt(dh.getG())
95+
.putMPInt(dh.getE())
96+
.putMPInt(f)
97+
.putMPInt(k);
98+
digest.update(buf.array(), buf.rpos(), buf.available());
99+
H = digest.digest();
100+
Signature signature = Factory.Named.Util.create(trans.getConfig().getSignatureFactories(),
101+
KeyType.fromKey(hostKey).toString());
102+
signature.init(hostKey, null);
103+
signature.update(H, 0, H.length);
104+
if (!signature.verify(sig))
105+
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
106+
"KeyExchange signature verification failed");
107+
return true;
108+
109+
}
110+
111+
private boolean parseGexGroup(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException {
112+
BigInteger p = buffer.readMPInt();
113+
BigInteger g = buffer.readMPInt();
114+
int bitLength = p.bitLength();
115+
if (bitLength < minBits || bitLength > maxBits) {
116+
throw new GeneralSecurityException("Server generated gex p is out of range (" + bitLength + " bits)");
117+
}
118+
log.debug("Received server p bitlength {}", bitLength);
119+
dh.init(p, g);
120+
log.debug("Sending {}", Message.KEX_DH_GEX_INIT);
121+
trans.write(new SSHPacket(Message.KEX_DH_GEX_INIT).putMPInt(dh.getE()));
122+
return false;
123+
}
124+
}

src/main/java/net/schmizz/sshj/transport/kex/DH.java

+7
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,11 @@ public BigInteger getK() {
7373
return K;
7474
}
7575

76+
public BigInteger getP() {
77+
return p;
78+
}
79+
80+
public BigInteger getG() {
81+
return g;
82+
}
7683
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package net.schmizz.sshj.transport.kex;
2+
3+
import net.schmizz.sshj.transport.digest.SHA1;
4+
5+
public class DHGexSHA1 extends AbstractDHGex {
6+
7+
/** Named factory for DHGexSHA1 key exchange */
8+
public static class Factory
9+
implements net.schmizz.sshj.common.Factory.Named<KeyExchange> {
10+
11+
@Override
12+
public KeyExchange create() {
13+
return new DHGexSHA1();
14+
}
15+
16+
@Override
17+
public String getName() {
18+
return "diffie-hellman-group-exchange-sha1";
19+
}
20+
}
21+
22+
public DHGexSHA1() {
23+
super(new SHA1());
24+
}
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package net.schmizz.sshj.transport.kex;
2+
3+
import net.schmizz.sshj.transport.digest.SHA256;
4+
5+
public class DHGexSHA256 extends AbstractDHGex {
6+
7+
/** Named factory for DHGexSHA256 key exchange */
8+
public static class Factory
9+
implements net.schmizz.sshj.common.Factory.Named<KeyExchange> {
10+
11+
@Override
12+
public KeyExchange create() {
13+
return new DHGexSHA256();
14+
}
15+
16+
@Override
17+
public String getName() {
18+
return "diffie-hellman-group-exchange-sha256";
19+
}
20+
}
21+
22+
public DHGexSHA256() {
23+
super(new SHA256());
24+
}
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package net.schmizz.sshj.transport.kex;
2+
3+
import net.schmizz.sshj.common.Buffer;
4+
import net.schmizz.sshj.transport.Transport;
5+
import net.schmizz.sshj.transport.TransportException;
6+
7+
import java.security.GeneralSecurityException;
8+
import java.util.Arrays;
9+
10+
/**
11+
* Created by ajvanerp on 29/10/15.
12+
*/
13+
public abstract class KeyExchangeBase implements KeyExchange {
14+
protected Transport trans;
15+
16+
private String V_S;
17+
private String V_C;
18+
private byte[] I_S;
19+
private byte[] I_C;
20+
21+
@Override
22+
public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException {
23+
this.trans = trans;
24+
this.V_S = V_S;
25+
this.V_C = V_C;
26+
this.I_S = Arrays.copyOf(I_S, I_S.length);
27+
this.I_C = Arrays.copyOf(I_C, I_C.length);
28+
}
29+
30+
protected Buffer.PlainBuffer initializedBuffer() {
31+
return new Buffer.PlainBuffer()
32+
.putString(V_C)
33+
.putString(V_S)
34+
.putString(I_C)
35+
.putString(I_S);
36+
}
37+
}

src/test/java/com/hierynomus/sshj/test/SshFixture.java

+4
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,8 @@ public void stopServer() {
158158
}
159159
}
160160
}
161+
162+
public SshServer getServer() {
163+
return server;
164+
}
161165
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.hierynomus.sshj.transport.kex;
2+
3+
import com.hierynomus.sshj.test.SshFixture;
4+
import net.schmizz.sshj.SSHClient;
5+
import org.apache.sshd.common.KeyExchange;
6+
import org.apache.sshd.common.NamedFactory;
7+
import org.apache.sshd.server.kex.DHGEX;
8+
import org.apache.sshd.server.kex.DHGEX256;
9+
import org.junit.After;
10+
import org.junit.Rule;
11+
import org.junit.Test;
12+
13+
import java.io.IOException;
14+
import java.util.Arrays;
15+
import java.util.Collections;
16+
17+
import static org.hamcrest.MatcherAssert.assertThat;
18+
19+
public class DiffieHellmanGroupExchangeTest {
20+
@Rule
21+
public SshFixture fixture = new SshFixture(false);
22+
23+
@After
24+
public void stopServer() {
25+
fixture.stopServer();
26+
}
27+
28+
@Test
29+
public void shouldKexWithGroupExchangeSha1() throws IOException {
30+
setupAndCheckKex(new DHGEX.Factory());
31+
}
32+
33+
@Test
34+
public void shouldKexWithGroupExchangeSha256() throws IOException {
35+
setupAndCheckKex(new DHGEX256.Factory());
36+
}
37+
38+
private void setupAndCheckKex(NamedFactory<KeyExchange> factory) throws IOException {
39+
fixture.getServer().setKeyExchangeFactories(Collections.singletonList(factory));
40+
fixture.start();
41+
SSHClient sshClient = fixture.setupConnectedDefaultClient();
42+
assertThat("should be connected", sshClient.isConnected());
43+
sshClient.disconnect();
44+
}
45+
}

0 commit comments

Comments
 (0)