Skip to content

Commit

Permalink
[FIXED] Drain() infinite loop and add test for concurrent Next()
Browse files Browse the repository at this point in the history
…calls (#1525)

* Added failing graceful shutdown test for MessagesContext.Drain method

* Minimize lock scope in pullSubscription.Next to allow for cleanup

Fixes possible deadlock when Next() is waiting and holding the lock and
cleanup() waiting for the lock to unsubscribe.

* Remove unused drained channel from pullSubscription

* Added test for auto unsubscribe with concurrent calls

* Revert locking in Next and remove cleanup lock

* Prevent hanging in the auto-unsubscribe test

Added comment for the cleanup function on why it doesn't need to hold
the lock.
  • Loading branch information
mdawar authored Jan 15, 2024
1 parent 61196eb commit a8a8d18
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 61 deletions.
15 changes: 5 additions & 10 deletions jetstream/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ type (
closed uint32
draining uint32
done chan struct{}
drained chan struct{}
connStatusChanged chan nats.Status
fetchNext chan *pullRequest
consumeOpts *consumeOpts
Expand Down Expand Up @@ -476,7 +475,6 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error
id: consumeID,
consumer: p,
done: make(chan struct{}, 1),
drained: make(chan struct{}, 1),
msgs: msgs,
errs: make(chan error, 1),
fetchNext: make(chan *pullRequest, 1),
Expand Down Expand Up @@ -560,12 +558,6 @@ func (s *pullSubscription) Next() (Msg, error) {
for {
s.checkPending()
select {
case <-s.done:
drainMode := atomic.LoadUint32(&s.draining) == 1
if drainMode {
continue
}
return nil, ErrMsgIteratorClosed
case msg, ok := <-s.msgs:
if !ok {
// if msgs channel is closed, it means that subscription was either drained or stopped
Expand Down Expand Up @@ -914,8 +906,11 @@ func (s *pullSubscription) scheduleHeartbeatCheck(dur time.Duration) *hbMonitor
}

func (s *pullSubscription) cleanup() {
s.Lock()
defer s.Unlock()
// For now this function does not need to hold the lock.
// Holding the lock here might cause a deadlock if Next()
// is already holding the lock and waiting.
// The fields that are read (subscription, hbMonitor)
// are read only (Only written on creation of pullSubscription).
if s.subscription == nil || !s.subscription.IsValid() {
return
}
Expand Down
217 changes: 166 additions & 51 deletions jetstream/test/pull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,111 @@ func TestPullConsumerMessages(t *testing.T) {
}
})

t.Run("with auto unsubscribe concurrent", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "test", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

it, err := c.Messages(jetstream.StopAfter(50), jetstream.PullMaxMessages(40))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

for i := 0; i < 100; i++ {
if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil {
t.Fatalf("Unexpected error during publish: %s", err)
}
}

var mu sync.Mutex // Mutex to guard the msgs slice.
msgs := make([]jetstream.Msg, 0)
var wg sync.WaitGroup

wg.Add(50)
for i := 0; i < 50; i++ {
go func() {
defer wg.Done()

msg, err := it.Next()
if err != nil {
return
}

ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := msg.DoubleAck(ctx); err == nil {
// Only append the msg if ack is successful.
mu.Lock()
msgs = append(msgs, msg)
mu.Unlock()
}
}()
}

wg.Wait()

// Call Next in a goroutine so we can timeout if it doesn't return.
errs := make(chan error)
go func() {
// This call should return the error ErrMsgIteratorClosed.
_, err := it.Next()
errs <- err
}()

timer := time.NewTimer(5 * time.Second)
defer timer.Stop()

select {
case <-timer.C:
t.Fatal("Timed out waiting for Next() to return")
case err := <-errs:
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Unexpected error: %v", err)
}
}

mu.Lock()
wantLen, gotLen := 50, len(msgs)
mu.Unlock()
if wantLen != gotLen {
t.Fatalf("Unexpected received message count; want %d; got %d", wantLen, gotLen)
}

ci, err := c.Info(ctx)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if ci.NumPending != 50 {
t.Fatalf("Unexpected number of pending messages; want 50; got %d", ci.NumPending)
}
if ci.NumAckPending != 0 {
t.Fatalf("Unexpected number of ack pending messages; want 0; got %d", ci.NumAckPending)
}
if ci.NumWaiting != 0 {
t.Fatalf("Unexpected number of waiting pull requests; want 0; got %d", ci.NumWaiting)
}
})

t.Run("create iterator, stop, then create again", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)
Expand Down Expand Up @@ -1293,69 +1398,79 @@ func TestPullConsumerMessages(t *testing.T) {
})

t.Run("with graceful shutdown", func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)

nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
cases := map[string]func(jetstream.MessagesContext){
"stop": func(mc jetstream.MessagesContext) { mc.Stop() },
"drain": func(mc jetstream.MessagesContext) { mc.Drain() },
}

js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()
for name, unsubscribe := range cases {
t.Run(name, func(t *testing.T) {
srv := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, srv)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

it, err := c.Messages()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
nc, err := nats.Connect(srv.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

publishTestMsgs(t, nc)
js, err := jetstream.New(nc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer nc.Close()

errs := make(chan error)
msgs := make([]jetstream.Msg, 0)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

go func() {
for {
msg, err := it.Next()
it, err := c.Messages()
if err != nil {
errs <- err
return
t.Fatalf("Unexpected error: %v", err)
}
msg.Ack()
msgs = append(msgs, msg)
}
}()

time.Sleep(10 * time.Millisecond)
it.Stop() // Next() should return ErrMsgIteratorClosed
publishTestMsgs(t, nc)

timeout := time.NewTimer(5 * time.Second)
errs := make(chan error)
msgs := make([]jetstream.Msg, 0)

select {
case <-timeout.C:
t.Fatal("Timed out waiting for Next() to return after Stop()")
case err := <-errs:
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Unexpected error: %v", err)
}
go func() {
for {
msg, err := it.Next()
if err != nil {
errs <- err
return
}
msg.Ack()
msgs = append(msgs, msg)
}
}()

if len(msgs) != len(testMsgs) {
t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
}
time.Sleep(10 * time.Millisecond)
unsubscribe(it) // Next() should return ErrMsgIteratorClosed

timer := time.NewTimer(5 * time.Second)
defer timer.Stop()

select {
case <-timer.C:
t.Fatal("Timed out waiting for Next() to return")
case err := <-errs:
if !errors.Is(err, jetstream.ErrMsgIteratorClosed) {
t.Fatalf("Unexpected error: %v", err)
}

if len(msgs) != len(testMsgs) {
t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs))
}
}
})
}
})

Expand Down

0 comments on commit a8a8d18

Please sign in to comment.