diff --git a/pubsub/batcher/batcher.go b/pubsub/batcher/batcher.go index 1a6bae812b..bbcc3de266 100644 --- a/pubsub/batcher/batcher.go +++ b/pubsub/batcher/batcher.go @@ -213,16 +213,17 @@ func (b *Batcher) AddNoWait(item interface{}) <-chan error { } if batch == nil && len(b.pending) > 0 && b.opts.BatchTimeout > 0 { - // If the batch size timeout is zero, this is one of the first items to - // be added to the batch under the minimum batch size. Record when this - // happens so that .nextBatch() can grab the batch on timeout. - if b.batchSizeTimeout.IsZero() { - b.batchSizeTimeout = time.Now() - } // Ensure that we send the batch after the given timeout. Only one // concurrent process can run this goroutine, ensuring that we don't // duplicate work. if atomic.CompareAndSwapInt32(&b.batchTimeoutRunning, 0, 1) { + // If the batch size timeout is zero, this is one of the first items to + // be added to the batch under the minimum batch size. Record when this + // happens so that .nextBatch() can grab the batch on timeout. + if b.batchSizeTimeout.IsZero() { + b.batchSizeTimeout = time.Now() + } + go func() { <-time.After(b.opts.BatchTimeout) b.batchTimeoutRunning = 0 @@ -300,10 +301,12 @@ func (b *Batcher) respectMinBatchSize() bool { // If we're shutting down, do not respect minimums. This takes priority. return false } - if b.opts.BatchTimeout > 0 && time.Since(b.batchSizeTimeout) >= b.opts.BatchTimeout { + if b.opts.BatchTimeout > 0 { // If we have a maximum wait before sending batches below the minimum, and we've // waited longer than that period, do not respect minimum batches and send! - return false + if !b.batchSizeTimeout.IsZero() && time.Since(b.batchSizeTimeout) >= b.opts.BatchTimeout { + return false + } } // At this point, either we're not shutting down and we're not forcing a batch // due to timeouts. Respect the batch size. diff --git a/pubsub/batcher/batcher_test.go b/pubsub/batcher/batcher_test.go index 9b0ee3c055..3fe3fff1ed 100644 --- a/pubsub/batcher/batcher_test.go +++ b/pubsub/batcher/batcher_test.go @@ -171,6 +171,35 @@ func TestMinBatchSize(t *testing.T) { } } +// TestMinBatchSizeFlushesAfterTimeout ensures that Shutdown() flushes batches, even if +// the pending count is less than the minimum batch size. +func TestMinBatchSizeFlushesAfterTimeout(t *testing.T) { + var got [][]int + + batchSize := 3 + opts := &batcher.Options{MinBatchSize: batchSize, BatchTimeout: 10 * time.Millisecond} + + b := batcher.New(reflect.TypeOf(int(0)), opts, func(items interface{}) error { + got = append(got, items.([]int)) + return nil + }) + for i := 0; i < (batchSize - 1); i++ { + b.AddNoWait(i) + } + + // Ensure that we've received nothing + if len(got) > 0 { + t.Errorf("got batch unexpectedly: %+v", got) + } + + <-time.After(opts.BatchTimeout + 5*time.Millisecond) + + want := [][]int{{0, 1}} + if !cmp.Equal(got, want) { + t.Errorf("got %+v, want %+v after timeout", got, want) + } +} + // TestMinBatchSizeFlushesOnShutdown ensures that Shutdown() flushes batches, even if // the pending count is less than the minimum batch size. func TestMinBatchSizeFlushesOnShutdown(t *testing.T) {