Skip to content

Commit 3242548

Browse files
committed
[ADDED] Force reconnect to the server
Signed-off-by: Piotr Piotrowski <[email protected]>
1 parent 8894a27 commit 3242548

6 files changed

+243
-52
lines changed

example_test.go

+13
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,19 @@ func ExampleConn_Subscribe() {
8989
})
9090
}
9191

92+
func ExampleConn_Reconnect() {
93+
nc, _ := nats.Connect(nats.DefaultURL)
94+
defer nc.Close()
95+
96+
nc.Subscribe("foo", func(m *nats.Msg) {
97+
fmt.Printf("Received a message: %s\n", string(m.Data))
98+
})
99+
100+
// Reconnect to the server.
101+
// The subscription will be recreated after the reconnect.
102+
nc.Reconnect()
103+
}
104+
92105
// This Example shows a synchronous subscriber.
93106
func ExampleConn_SubscribeSync() {
94107
nc, _ := nats.Connect(nats.DefaultURL)

nats.go

+52-7
Original file line numberDiff line numberDiff line change
@@ -2161,6 +2161,47 @@ func (nc *Conn) waitForExits() {
21612161
nc.wg.Wait()
21622162
}
21632163

2164+
// Reconnect forces a reconnect attempt to the server.
2165+
// This is a non-blocking call and will start the reconnect
2166+
// process without waiting for it to complete.
2167+
//
2168+
// If the connection is already in the process of reconnecting,
2169+
// this call will force an immediate reconnect attempt (bypassing
2170+
// the current reconnect delay).
2171+
func (nc *Conn) Reconnect() error {
2172+
nc.mu.Lock()
2173+
defer nc.mu.Unlock()
2174+
2175+
if nc.isClosed() {
2176+
return ErrConnectionClosed
2177+
}
2178+
if nc.isReconnecting() {
2179+
// if we're already reconnecting, force a reconnect attempt
2180+
// even if we're in the middle of a backoff
2181+
if nc.rqch != nil {
2182+
close(nc.rqch)
2183+
}
2184+
return nil
2185+
}
2186+
2187+
// Clear any queued pongs
2188+
nc.clearPendingFlushCalls()
2189+
2190+
// Clear any queued and blocking requests.
2191+
nc.clearPendingRequestCalls()
2192+
2193+
// Stop ping timer if set.
2194+
nc.stopPingTimer()
2195+
2196+
// Go ahead and make sure we have flushed the outbound buffer.
2197+
nc.bw.flush()
2198+
nc.conn.Close()
2199+
2200+
nc.changeConnStatus(RECONNECTING)
2201+
go nc.doReconnect(nil, true)
2202+
return nil
2203+
}
2204+
21642205
// ConnectedUrl reports the connected server's URL
21652206
func (nc *Conn) ConnectedUrl() string {
21662207
if nc == nil {
@@ -2420,7 +2461,7 @@ func (nc *Conn) connect() (bool, error) {
24202461
nc.setup()
24212462
nc.changeConnStatus(RECONNECTING)
24222463
nc.bw.switchToPending()
2423-
go nc.doReconnect(ErrNoServers)
2464+
go nc.doReconnect(ErrNoServers, false)
24242465
err = nil
24252466
} else {
24262467
nc.current = nil
@@ -2720,7 +2761,7 @@ func (nc *Conn) stopPingTimer() {
27202761

27212762
// Try to reconnect using the option parameters.
27222763
// This function assumes we are allowed to reconnect.
2723-
func (nc *Conn) doReconnect(err error) {
2764+
func (nc *Conn) doReconnect(err error, forceReconnect bool) {
27242765
// We want to make sure we have the other watchers shutdown properly
27252766
// here before we proceed past this point.
27262767
nc.waitForExits()
@@ -2776,7 +2817,8 @@ func (nc *Conn) doReconnect(err error) {
27762817
break
27772818
}
27782819

2779-
doSleep := i+1 >= len(nc.srvPool)
2820+
doSleep := i+1 >= len(nc.srvPool) && !forceReconnect
2821+
forceReconnect = false
27802822
nc.mu.Unlock()
27812823

27822824
if !doSleep {
@@ -2803,6 +2845,12 @@ func (nc *Conn) doReconnect(err error) {
28032845
select {
28042846
case <-rqch:
28052847
rt.Stop()
2848+
2849+
// we need to reset the rqch channel to avoid
2850+
// closing a closed channel in the next iteration
2851+
nc.mu.Lock()
2852+
nc.rqch = make(chan struct{})
2853+
nc.mu.Unlock()
28062854
case <-rt.C:
28072855
}
28082856
}
@@ -2872,9 +2920,6 @@ func (nc *Conn) doReconnect(err error) {
28722920
// Done with the pending buffer
28732921
nc.bw.doneWithPending()
28742922

2875-
// This is where we are truly connected.
2876-
nc.status = CONNECTED
2877-
28782923
// Queue up the correct callback. If we are in initial connect state
28792924
// (using retry on failed connect), we will call the ConnectedCB,
28802925
// otherwise the ReconnectedCB.
@@ -2930,7 +2975,7 @@ func (nc *Conn) processOpErr(err error) {
29302975
// Clear any queued pongs, e.g. pending flush calls.
29312976
nc.clearPendingFlushCalls()
29322977

2933-
go nc.doReconnect(err)
2978+
go nc.doReconnect(err, false)
29342979
nc.mu.Unlock()
29352980
return
29362981
}

test/conn_test.go

+4-14
Original file line numberDiff line numberDiff line change
@@ -2946,16 +2946,6 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) {
29462946
}
29472947

29482948
func TestConnStatusChangedEvents(t *testing.T) {
2949-
waitForStatus := func(t *testing.T, ch chan nats.Status, expected nats.Status) {
2950-
select {
2951-
case s := <-ch:
2952-
if s != expected {
2953-
t.Fatalf("Expected status: %s; got: %s", expected, s)
2954-
}
2955-
case <-time.After(5 * time.Second):
2956-
t.Fatalf("Timeout waiting for status %q", expected)
2957-
}
2958-
}
29592949
t.Run("default events", func(t *testing.T) {
29602950
s := RunDefaultServer()
29612951
nc, err := nats.Connect(s.ClientURL())
@@ -2978,15 +2968,15 @@ func TestConnStatusChangedEvents(t *testing.T) {
29782968
time.Sleep(50 * time.Millisecond)
29792969

29802970
s.Shutdown()
2981-
waitForStatus(t, newStatus, nats.RECONNECTING)
2971+
WaitOnChannel(t, newStatus, nats.RECONNECTING)
29822972

29832973
s = RunDefaultServer()
29842974
defer s.Shutdown()
29852975

2986-
waitForStatus(t, newStatus, nats.CONNECTED)
2976+
WaitOnChannel(t, newStatus, nats.CONNECTED)
29872977

29882978
nc.Close()
2989-
waitForStatus(t, newStatus, nats.CLOSED)
2979+
WaitOnChannel(t, newStatus, nats.CLOSED)
29902980

29912981
select {
29922982
case s := <-newStatus:
@@ -3019,7 +3009,7 @@ func TestConnStatusChangedEvents(t *testing.T) {
30193009
s = RunDefaultServer()
30203010
defer s.Shutdown()
30213011
nc.Close()
3022-
waitForStatus(t, newStatus, nats.CLOSED)
3012+
WaitOnChannel(t, newStatus, nats.CLOSED)
30233013

30243014
select {
30253015
case s := <-newStatus:

test/helper_test.go

+12
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ func WaitTime(ch chan bool, timeout time.Duration) error {
5454
return errors.New("timeout")
5555
}
5656

57+
func WaitOnChannel[T comparable](t *testing.T, ch <-chan T, expected T) {
58+
t.Helper()
59+
select {
60+
case s := <-ch:
61+
if s != expected {
62+
t.Fatalf("Expected result: %v; got: %v", expected, s)
63+
}
64+
case <-time.After(5 * time.Second):
65+
t.Fatalf("Timeout waiting for result %v", expected)
66+
}
67+
}
68+
5769
func stackFatalf(t tLogger, f string, args ...any) {
5870
lines := make([]string, 0, 32)
5971
msg := fmt.Sprintf(f, args...)

test/reconnect_test.go

+155-12
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,7 @@ func TestAuthExpiredReconnect(t *testing.T) {
853853

854854
jwtCB := func() (string, error) {
855855
claims := jwt.NewUserClaims("test")
856-
claims.Expires = time.Now().Add(500 * time.Millisecond).Unix()
856+
claims.Expires = time.Now().Add(time.Second).Unix()
857857
claims.Subject = upub
858858
jwt, err := claims.Encode(akp)
859859
if err != nil {
@@ -884,21 +884,164 @@ func TestAuthExpiredReconnect(t *testing.T) {
884884
case <-time.After(2 * time.Second):
885885
t.Fatal("Did not get the auth expired error")
886886
}
887-
select {
888-
case s := <-stasusCh:
889-
if s != nats.RECONNECTING {
890-
t.Fatalf("Expected to be in reconnecting state after jwt expires, got %v", s)
887+
WaitOnChannel(t, stasusCh, nats.RECONNECTING)
888+
WaitOnChannel(t, stasusCh, nats.CONNECTED)
889+
nc.Close()
890+
}
891+
892+
func TestForceReconnect(t *testing.T) {
893+
s := RunDefaultServer()
894+
895+
nc, err := nats.Connect(s.ClientURL(), nats.ReconnectWait(10*time.Second))
896+
if err != nil {
897+
t.Fatalf("Unexpected error on connect: %v", err)
898+
}
899+
// defer nc.Close()
900+
901+
statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED)
902+
defer close(statusCh)
903+
newStatus := make(chan nats.Status, 10)
904+
// non-blocking channel, so we need to be constantly listening
905+
go func() {
906+
for {
907+
s, ok := <-statusCh
908+
if !ok {
909+
return
910+
}
911+
newStatus <- s
891912
}
892-
case <-time.After(2 * time.Second):
893-
t.Fatal("Did not get the status change")
913+
}()
914+
915+
sub, err := nc.SubscribeSync("foo")
916+
if err != nil {
917+
t.Fatalf("Error on subscribe: %v", err)
918+
}
919+
if err := nc.Publish("foo", []byte("msg")); err != nil {
920+
t.Fatalf("Error on publish: %v", err)
921+
}
922+
_, err = sub.NextMsg(time.Second)
923+
if err != nil {
924+
t.Fatalf("Error getting message: %v", err)
925+
}
926+
927+
// Force a reconnect
928+
err = nc.Reconnect()
929+
if err != nil {
930+
t.Fatalf("Unexpected error on reconnect: %v", err)
931+
}
932+
933+
WaitOnChannel(t, newStatus, nats.RECONNECTING)
934+
WaitOnChannel(t, newStatus, nats.CONNECTED)
935+
936+
if err := nc.Publish("foo", []byte("msg")); err != nil {
937+
t.Fatalf("Error on publish: %v", err)
894938
}
939+
_, err = sub.NextMsg(time.Second)
940+
if err != nil {
941+
t.Fatalf("Error getting message: %v", err)
942+
}
943+
944+
// shut down the server and then force a reconnect
945+
s.Shutdown()
946+
WaitOnChannel(t, newStatus, nats.RECONNECTING)
947+
_, err = sub.NextMsg(100 * time.Millisecond)
948+
if err == nil {
949+
t.Fatal("Expected error getting message")
950+
}
951+
952+
// restart server
953+
s = RunDefaultServer()
954+
defer s.Shutdown()
955+
956+
if err := nc.Reconnect(); err != nil {
957+
t.Fatalf("Unexpected error on reconnect: %v", err)
958+
}
959+
// wait for the reconnect
960+
// because the connection has a long ReconnectWait,
961+
// if the forced reconnect does not work, the test will time out
962+
WaitOnChannel(t, newStatus, nats.CONNECTED)
963+
964+
if err := nc.Publish("foo", []byte("msg")); err != nil {
965+
t.Fatalf("Error on publish: %v", err)
966+
}
967+
_, err = sub.NextMsg(time.Second)
968+
if err != nil {
969+
t.Fatalf("Error getting message: %v", err)
970+
}
971+
nc.Close()
972+
}
973+
974+
func TestAuthExpiredForceReconnect(t *testing.T) {
975+
ts := runTrustServer()
976+
defer ts.Shutdown()
977+
978+
_, err := nats.Connect(ts.ClientURL())
979+
if err == nil {
980+
t.Fatalf("Expecting an error on connect")
981+
}
982+
ukp, err := nkeys.FromSeed(uSeed)
983+
if err != nil {
984+
t.Fatalf("Error creating user key pair: %v", err)
985+
}
986+
upub, err := ukp.PublicKey()
987+
if err != nil {
988+
t.Fatalf("Error getting user public key: %v", err)
989+
}
990+
akp, err := nkeys.FromSeed(aSeed)
991+
if err != nil {
992+
t.Fatalf("Error creating account key pair: %v", err)
993+
}
994+
995+
jwtCB := func() (string, error) {
996+
claims := jwt.NewUserClaims("test")
997+
claims.Expires = time.Now().Add(time.Second).Unix()
998+
claims.Subject = upub
999+
jwt, err := claims.Encode(akp)
1000+
if err != nil {
1001+
return "", err
1002+
}
1003+
return jwt, nil
1004+
}
1005+
sigCB := func(nonce []byte) ([]byte, error) {
1006+
kp, _ := nkeys.FromSeed(uSeed)
1007+
sig, _ := kp.Sign(nonce)
1008+
return sig, nil
1009+
}
1010+
1011+
errCh := make(chan error, 1)
1012+
nc, err := nats.Connect(ts.ClientURL(), nats.UserJWT(jwtCB, sigCB), nats.ReconnectWait(10*time.Second),
1013+
nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) {
1014+
errCh <- err
1015+
}))
1016+
if err != nil {
1017+
t.Fatalf("Expected to connect, got %v", err)
1018+
}
1019+
defer nc.Close()
1020+
statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED)
1021+
defer close(statusCh)
1022+
newStatus := make(chan nats.Status, 10)
1023+
// non-blocking channel, so we need to be constantly listening
1024+
go func() {
1025+
for {
1026+
s, ok := <-statusCh
1027+
if !ok {
1028+
return
1029+
}
1030+
newStatus <- s
1031+
}
1032+
}()
1033+
time.Sleep(100 * time.Millisecond)
8951034
select {
896-
case s := <-stasusCh:
897-
if s != nats.CONNECTED {
898-
t.Fatalf("Expected to reconnect, got %v", s)
1035+
case err := <-errCh:
1036+
if !errors.Is(err, nats.ErrAuthExpired) {
1037+
t.Fatalf("Expected auth expired error, got %v", err)
8991038
}
9001039
case <-time.After(2 * time.Second):
901-
t.Fatal("Did not get the status change")
1040+
t.Fatal("Did not get the auth expired error")
9021041
}
903-
nc.Close()
1042+
if err := nc.Reconnect(); err != nil {
1043+
t.Fatalf("Unexpected error on reconnect: %v", err)
1044+
}
1045+
WaitOnChannel(t, newStatus, nats.RECONNECTING)
1046+
WaitOnChannel(t, newStatus, nats.CONNECTED)
9041047
}

0 commit comments

Comments
 (0)