From 29dd974e90c1d8fa3ab1ede72fb70937f274d3b5 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Sat, 27 Apr 2024 02:07:17 +0200 Subject: [PATCH 1/3] [ADDED] Force reconnect to the server Signed-off-by: Piotr Piotrowski --- example_test.go | 13 ++++ nats.go | 59 +++++++++++++-- test/conn_test.go | 18 +---- test/helper_test.go | 12 +++ test/reconnect_test.go | 166 ++++++++++++++++++++++++++++++++++++++--- test/sub_test.go | 26 ++----- 6 files changed, 242 insertions(+), 52 deletions(-) diff --git a/example_test.go b/example_test.go index 6aa93636c..53b608e36 100644 --- a/example_test.go +++ b/example_test.go @@ -89,6 +89,19 @@ func ExampleConn_Subscribe() { }) } +func ExampleConn_Reconnect() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + nc.Subscribe("foo", func(m *nats.Msg) { + fmt.Printf("Received a message: %s\n", string(m.Data)) + }) + + // Reconnect to the server. + // the subscription will be recreated after the reconnect. + nc.Reconnect() +} + // This Example shows a synchronous subscriber. func ExampleConn_SubscribeSync() { nc, _ := nats.Connect(nats.DefaultURL) diff --git a/nats.go b/nats.go index 8c0796a89..c3b0170ae 100644 --- a/nats.go +++ b/nats.go @@ -2161,6 +2161,47 @@ func (nc *Conn) waitForExits() { nc.wg.Wait() } +// Reconnect forces a reconnect attempt to the server. +// This is a non-blocking call and will start the reconnect +// process without waiting for it to complete. +// +// If the connection is already in the process of reconnecting, +// this call will force an immediate reconnect attempt (bypassing +// the current reconnect delay). +func (nc *Conn) Reconnect() error { + nc.mu.Lock() + defer nc.mu.Unlock() + + if nc.isClosed() { + return ErrConnectionClosed + } + if nc.isReconnecting() { + // if we're already reconnecting, force a reconnect attempt + // even if we're in the middle of a backoff + if nc.rqch != nil { + close(nc.rqch) + } + return nil + } + + // Clear any queued pongs + nc.clearPendingFlushCalls() + + // Clear any queued and blocking requests. + nc.clearPendingRequestCalls() + + // Stop ping timer if set. + nc.stopPingTimer() + + // Go ahead and make sure we have flushed the outbound + nc.bw.flush() + nc.conn.Close() + + nc.changeConnStatus(RECONNECTING) + go nc.doReconnect(nil, true) + return nil +} + // ConnectedUrl reports the connected server's URL func (nc *Conn) ConnectedUrl() string { if nc == nil { @@ -2420,7 +2461,7 @@ func (nc *Conn) connect() (bool, error) { nc.setup() nc.changeConnStatus(RECONNECTING) nc.bw.switchToPending() - go nc.doReconnect(ErrNoServers) + go nc.doReconnect(ErrNoServers, false) err = nil } else { nc.current = nil @@ -2720,7 +2761,7 @@ func (nc *Conn) stopPingTimer() { // Try to reconnect using the option parameters. // This function assumes we are allowed to reconnect. -func (nc *Conn) doReconnect(err error) { +func (nc *Conn) doReconnect(err error, forceReconnect bool) { // We want to make sure we have the other watchers shutdown properly // here before we proceed past this point. nc.waitForExits() @@ -2776,7 +2817,8 @@ func (nc *Conn) doReconnect(err error) { break } - doSleep := i+1 >= len(nc.srvPool) + doSleep := i+1 >= len(nc.srvPool) && !forceReconnect + forceReconnect = false nc.mu.Unlock() if !doSleep { @@ -2803,6 +2845,12 @@ func (nc *Conn) doReconnect(err error) { select { case <-rqch: rt.Stop() + + // we need to reset the rqch channel to avoid + // closing a closed channel in the next iteration + nc.mu.Lock() + nc.rqch = make(chan struct{}) + nc.mu.Unlock() case <-rt.C: } } @@ -2872,9 +2920,6 @@ func (nc *Conn) doReconnect(err error) { // Done with the pending buffer nc.bw.doneWithPending() - // This is where we are truly connected. - nc.status = CONNECTED - // Queue up the correct callback. If we are in initial connect state // (using retry on failed connect), we will call the ConnectedCB, // otherwise the ReconnectedCB. @@ -2930,7 +2975,7 @@ func (nc *Conn) processOpErr(err error) { // Clear any queued pongs, e.g. pending flush calls. nc.clearPendingFlushCalls() - go nc.doReconnect(err) + go nc.doReconnect(err, false) nc.mu.Unlock() return } diff --git a/test/conn_test.go b/test/conn_test.go index 7e5fcab01..afc5025b3 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -2946,16 +2946,6 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) { } func TestConnStatusChangedEvents(t *testing.T) { - waitForStatus := func(t *testing.T, ch chan nats.Status, expected nats.Status) { - select { - case s := <-ch: - if s != expected { - t.Fatalf("Expected status: %s; got: %s", expected, s) - } - case <-time.After(5 * time.Second): - t.Fatalf("Timeout waiting for status %q", expected) - } - } t.Run("default events", func(t *testing.T) { s := RunDefaultServer() nc, err := nats.Connect(s.ClientURL()) @@ -2978,15 +2968,15 @@ func TestConnStatusChangedEvents(t *testing.T) { time.Sleep(50 * time.Millisecond) s.Shutdown() - waitForStatus(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.RECONNECTING) s = RunDefaultServer() defer s.Shutdown() - waitForStatus(t, newStatus, nats.CONNECTED) + WaitOnChannel(t, newStatus, nats.CONNECTED) nc.Close() - waitForStatus(t, newStatus, nats.CLOSED) + WaitOnChannel(t, newStatus, nats.CLOSED) select { case s := <-newStatus: @@ -3019,7 +3009,7 @@ func TestConnStatusChangedEvents(t *testing.T) { s = RunDefaultServer() defer s.Shutdown() nc.Close() - waitForStatus(t, newStatus, nats.CLOSED) + WaitOnChannel(t, newStatus, nats.CLOSED) select { case s := <-newStatus: diff --git a/test/helper_test.go b/test/helper_test.go index 9c04a40f9..7f2aedf0c 100644 --- a/test/helper_test.go +++ b/test/helper_test.go @@ -54,6 +54,18 @@ func WaitTime(ch chan bool, timeout time.Duration) error { return errors.New("timeout") } +func WaitOnChannel[T comparable](t *testing.T, ch <-chan T, expected T) { + t.Helper() + select { + case s := <-ch: + if s != expected { + t.Fatalf("Expected result: %v; got: %v", expected, s) + } + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for result %v", expected) + } +} + func stackFatalf(t tLogger, f string, args ...any) { lines := make([]string, 0, 32) msg := fmt.Sprintf(f, args...) diff --git a/test/reconnect_test.go b/test/reconnect_test.go index 66cc9b6ca..550150fe5 100644 --- a/test/reconnect_test.go +++ b/test/reconnect_test.go @@ -853,7 +853,7 @@ func TestAuthExpiredReconnect(t *testing.T) { jwtCB := func() (string, error) { claims := jwt.NewUserClaims("test") - claims.Expires = time.Now().Add(500 * time.Millisecond).Unix() + claims.Expires = time.Now().Add(time.Second).Unix() claims.Subject = upub jwt, err := claims.Encode(akp) if err != nil { @@ -884,21 +884,163 @@ func TestAuthExpiredReconnect(t *testing.T) { case <-time.After(2 * time.Second): t.Fatal("Did not get the auth expired error") } - select { - case s := <-stasusCh: - if s != nats.RECONNECTING { - t.Fatalf("Expected to be in reconnecting state after jwt expires, got %v", s) + WaitOnChannel(t, stasusCh, nats.RECONNECTING) + WaitOnChannel(t, stasusCh, nats.CONNECTED) + nc.Close() +} + +func TestForceReconnect(t *testing.T) { + s := RunDefaultServer() + + nc, err := nats.Connect(s.ClientURL(), nats.ReconnectWait(10*time.Second)) + if err != nil { + t.Fatalf("Unexpected error on connect: %v", err) + } + + statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + defer close(statusCh) + newStatus := make(chan nats.Status, 10) + // non-blocking channel, so we need to be constantly listening + go func() { + for { + s, ok := <-statusCh + if !ok { + return + } + newStatus <- s } - case <-time.After(2 * time.Second): - t.Fatal("Did not get the status change") + }() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + + // Force a reconnect + err = nc.Reconnect() + if err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + + WaitOnChannel(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.CONNECTED) + + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + + // shutdown server and then force a reconnect + s.Shutdown() + WaitOnChannel(t, newStatus, nats.RECONNECTING) + _, err = sub.NextMsg(100 * time.Millisecond) + if err == nil { + t.Fatal("Expected error getting message") + } + + // restart server + s = RunDefaultServer() + defer s.Shutdown() + + if err := nc.Reconnect(); err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + // wait for the reconnect + // because the connection has long ReconnectWait, + // if force reconnect does not work, the test will timeout + WaitOnChannel(t, newStatus, nats.CONNECTED) + + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + nc.Close() +} + +func TestAuthExpiredForceReconnect(t *testing.T) { + ts := runTrustServer() + defer ts.Shutdown() + + _, err := nats.Connect(ts.ClientURL()) + if err == nil { + t.Fatalf("Expecting an error on connect") + } + ukp, err := nkeys.FromSeed(uSeed) + if err != nil { + t.Fatalf("Error creating user key pair: %v", err) + } + upub, err := ukp.PublicKey() + if err != nil { + t.Fatalf("Error getting user public key: %v", err) + } + akp, err := nkeys.FromSeed(aSeed) + if err != nil { + t.Fatalf("Error creating account key pair: %v", err) + } + + jwtCB := func() (string, error) { + claims := jwt.NewUserClaims("test") + claims.Expires = time.Now().Add(time.Second).Unix() + claims.Subject = upub + jwt, err := claims.Encode(akp) + if err != nil { + return "", err + } + return jwt, nil + } + sigCB := func(nonce []byte) ([]byte, error) { + kp, _ := nkeys.FromSeed(uSeed) + sig, _ := kp.Sign(nonce) + return sig, nil + } + + errCh := make(chan error, 1) + nc, err := nats.Connect(ts.ClientURL(), nats.UserJWT(jwtCB, sigCB), nats.ReconnectWait(10*time.Second), + nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) { + errCh <- err + })) + if err != nil { + t.Fatalf("Expected to connect, got %v", err) + } + defer nc.Close() + statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + defer close(statusCh) + newStatus := make(chan nats.Status, 10) + // non-blocking channel, so we need to be constantly listening + go func() { + for { + s, ok := <-statusCh + if !ok { + return + } + newStatus <- s + } + }() + time.Sleep(100 * time.Millisecond) select { - case s := <-stasusCh: - if s != nats.CONNECTED { - t.Fatalf("Expected to reconnect, got %v", s) + case err := <-errCh: + if !errors.Is(err, nats.ErrAuthExpired) { + t.Fatalf("Expected auth expired error, got %v", err) } case <-time.After(2 * time.Second): - t.Fatal("Did not get the status change") + t.Fatal("Did not get the auth expired error") } - nc.Close() + if err := nc.Reconnect(); err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + WaitOnChannel(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.CONNECTED) } diff --git a/test/sub_test.go b/test/sub_test.go index 0bf2880c1..f0f83a8d5 100644 --- a/test/sub_test.go +++ b/test/sub_test.go @@ -1617,18 +1617,6 @@ func TestSubscribe_ClosedHandler(t *testing.T) { } func TestSubscriptionEvents(t *testing.T) { - - waitForStatus := func(t *testing.T, ch <-chan nats.SubStatus, expected nats.SubStatus) { - t.Helper() - select { - case s := <-ch: - if s != expected { - t.Fatalf("Expected status: %s; got: %s", expected, s) - } - case <-time.After(5 * time.Second): - t.Fatalf("Timeout waiting for status %q", expected) - } - } t.Run("default events", func(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() @@ -1651,19 +1639,19 @@ func TestSubscriptionEvents(t *testing.T) { status := sub.StatusChanged() // initial status - waitForStatus(t, status, nats.SubscriptionActive) + WaitOnChannel(t, status, nats.SubscriptionActive) for i := 0; i < 11; i++ { nc.Publish("foo", []byte("Hello")) } - waitForStatus(t, status, nats.SubscriptionSlowConsumer) + WaitOnChannel(t, status, nats.SubscriptionSlowConsumer) close(blockChan) sub.Drain() - waitForStatus(t, status, nats.SubscriptionDraining) + WaitOnChannel(t, status, nats.SubscriptionDraining) - waitForStatus(t, status, nats.SubscriptionClosed) + WaitOnChannel(t, status, nats.SubscriptionClosed) }) t.Run("slow consumer event only", func(t *testing.T) { @@ -1691,7 +1679,7 @@ func TestSubscriptionEvents(t *testing.T) { for i := 0; i < 20; i++ { nc.Publish("foo", []byte("Hello")) } - waitForStatus(t, status, nats.SubscriptionSlowConsumer) + WaitOnChannel(t, status, nats.SubscriptionSlowConsumer) close(blockChan) // now try with sync sub @@ -1706,7 +1694,7 @@ func TestSubscriptionEvents(t *testing.T) { for i := 0; i < 20; i++ { nc.Publish("foo", []byte("Hello")) } - waitForStatus(t, status, nats.SubscriptionSlowConsumer) + WaitOnChannel(t, status, nats.SubscriptionSlowConsumer) }) t.Run("do not block channel if it's not read", func(t *testing.T) { @@ -1730,7 +1718,7 @@ func TestSubscriptionEvents(t *testing.T) { } sub.SetPendingLimits(10, 1024) status := sub.StatusChanged() - waitForStatus(t, status, nats.SubscriptionActive) + WaitOnChannel(t, status, nats.SubscriptionActive) // chan length is 10, so make sure we switch state more times for i := 0; i < 20; i++ { From b92be5d5aa90670659c746a4f6c0bc0d81a394c9 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Mon, 6 May 2024 21:52:01 +0200 Subject: [PATCH 2/3] Add test veryfing if force reconnct works with NoReconnect option Signed-off-by: Piotr Piotrowski --- test/reconnect_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/test/reconnect_test.go b/test/reconnect_test.go index 550150fe5..6cd59a0fd 100644 --- a/test/reconnect_test.go +++ b/test/reconnect_test.go @@ -970,6 +970,61 @@ func TestForceReconnect(t *testing.T) { nc.Close() } +func TestForceReconnectDisallowReconnect(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL(), nats.NoReconnect()) + if err != nil { + t.Fatalf("Unexpected error on connect: %v", err) + } + defer nc.Close() + + statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + defer close(statusCh) + newStatus := make(chan nats.Status, 10) + // non-blocking channel, so we need to be constantly listening + go func() { + for { + s, ok := <-statusCh + if !ok { + return + } + newStatus <- s + } + }() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + + // Force a reconnect + err = nc.Reconnect() + if err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + + WaitOnChannel(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.CONNECTED) + + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + +} + func TestAuthExpiredForceReconnect(t *testing.T) { ts := runTrustServer() defer ts.Shutdown() From 8f8f9ae0c22b71fa319fcee841eac08e27be9933 Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Wed, 8 May 2024 12:00:46 -0700 Subject: [PATCH 3/3] Rename to ForceReconnect Signed-off-by: Piotr Piotrowski --- example_test.go | 4 ++-- nats.go | 4 ++-- test/reconnect_test.go | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/example_test.go b/example_test.go index 53b608e36..3ad367320 100644 --- a/example_test.go +++ b/example_test.go @@ -89,7 +89,7 @@ func ExampleConn_Subscribe() { }) } -func ExampleConn_Reconnect() { +func ExampleConn_ForceReconnect() { nc, _ := nats.Connect(nats.DefaultURL) defer nc.Close() @@ -99,7 +99,7 @@ func ExampleConn_Reconnect() { // Reconnect to the server. // the subscription will be recreated after the reconnect. - nc.Reconnect() + nc.ForceReconnect() } // This Example shows a synchronous subscriber. diff --git a/nats.go b/nats.go index c3b0170ae..d94c9a9c7 100644 --- a/nats.go +++ b/nats.go @@ -2161,14 +2161,14 @@ func (nc *Conn) waitForExits() { nc.wg.Wait() } -// Reconnect forces a reconnect attempt to the server. +// ForceReconnect forces a reconnect attempt to the server. // This is a non-blocking call and will start the reconnect // process without waiting for it to complete. // // If the connection is already in the process of reconnecting, // this call will force an immediate reconnect attempt (bypassing // the current reconnect delay). -func (nc *Conn) Reconnect() error { +func (nc *Conn) ForceReconnect() error { nc.mu.Lock() defer nc.mu.Unlock() diff --git a/test/reconnect_test.go b/test/reconnect_test.go index 6cd59a0fd..e543db72e 100644 --- a/test/reconnect_test.go +++ b/test/reconnect_test.go @@ -924,7 +924,7 @@ func TestForceReconnect(t *testing.T) { } // Force a reconnect - err = nc.Reconnect() + err = nc.ForceReconnect() if err != nil { t.Fatalf("Unexpected error on reconnect: %v", err) } @@ -952,7 +952,7 @@ func TestForceReconnect(t *testing.T) { s = RunDefaultServer() defer s.Shutdown() - if err := nc.Reconnect(); err != nil { + if err := nc.ForceReconnect(); err != nil { t.Fatalf("Unexpected error on reconnect: %v", err) } // wait for the reconnect @@ -1007,7 +1007,7 @@ func TestForceReconnectDisallowReconnect(t *testing.T) { } // Force a reconnect - err = nc.Reconnect() + err = nc.ForceReconnect() if err != nil { t.Fatalf("Unexpected error on reconnect: %v", err) } @@ -1093,7 +1093,7 @@ func TestAuthExpiredForceReconnect(t *testing.T) { case <-time.After(2 * time.Second): t.Fatal("Did not get the auth expired error") } - if err := nc.Reconnect(); err != nil { + if err := nc.ForceReconnect(); err != nil { t.Fatalf("Unexpected error on reconnect: %v", err) } WaitOnChannel(t, newStatus, nats.RECONNECTING)