Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize rate limiter for heavier load #125

Merged
merged 2 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 79 additions & 44 deletions v1/rate_limit.go
Original file line number Diff line number Diff line change
@@ -1,100 +1,135 @@
package v1

import (
"hash/fnv"
"runtime"
"sync"
"sync/atomic"
"time"
)

// shardsPerCoreMultiplier sets the shard count to double the CPU count,
// reducing lock contention across concurrently-used IDs.
const shardsPerCoreMultiplier = 2

// NoopLimiter implements Limiter but doesn't limit anything.
var NoopLimiter Limiter = &noopLimiter{}

// token tracks the request count for one ID within the current
// one-second window. It is only mutated while holding the owning
// shard's mutex.
type token struct {
	rps     uint32 // requests observed in the current window
	lastUse int64  // window start as a Unix timestamp in nanoseconds
}

// Limiter interface for rate-limiting.
type Limiter interface {
	// Obtain the right to send a request. Should lock the execution
	// if the current goroutine needs to wait.
	Obtain(id string)
}

// TokensBucket implements basic Limiter with fixed window and fixed amount of tokens per window.
// TokensBucket implements a sharded rate limiter with fixed window and tokens.
type TokensBucket struct {
maxRPS uint32
tokens sync.Map
unusedTokenTime time.Duration
unusedTokenTime int64 // in nanoseconds
checkTokenTime time.Duration
shards []*tokenShard
shardCount uint32
cancel atomic.Bool
sleep sleeper
}

// tokenShard guards one partition of the per-ID token map.
type tokenShard struct {
	tokens map[string]*token
	mu     sync.Mutex
}

// NewTokensBucket creates a sharded token bucket limiter.
// maxRPS is the per-ID requests-per-second cap; unusedTokenTime is how
// long an idle ID's state is retained; checkTokenTime is the cleanup
// interval.
func NewTokensBucket(maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration) Limiter {
	shardCount := uint32(runtime.NumCPU() * shardsPerCoreMultiplier)
	shards := make([]*tokenShard, shardCount)
	for i := range shards {
		shards[i] = &tokenShard{tokens: make(map[string]*token)}
	}

	bucket := &TokensBucket{
		maxRPS:          maxRPS,
		unusedTokenTime: unusedTokenTime.Nanoseconds(),
		checkTokenTime:  checkTokenTime,
		shards:          shards,
		shardCount:      shardCount,
		sleep:           realSleeper{},
	}

	// NOTE(review): the cleanup goroutine keeps bucket reachable, which
	// likely prevents the finalizer from ever firing (so cancel is never
	// set and the goroutine leaks). Consider an explicit Close method —
	// confirm against runtime.SetFinalizer reachability semantics.
	go bucket.cleanupRoutine()
	runtime.SetFinalizer(bucket, destructBucket)
	return bucket
}

// Obtain request hit. Will throttle RPS.
func (m *TokensBucket) Obtain(id string) {
val, ok := m.tokens.Load(id)
if !ok {
token := &token{}
token.lastUse.Store(time.Now())
token.rps.Store(1)
m.tokens.Store(id, token)
shard := m.getShard(id)

shard.mu.Lock()
defer shard.mu.Unlock()

item, exists := shard.tokens[id]
now := time.Now().UnixNano()

if !exists {
shard.tokens[id] = &token{
rps: 1,
lastUse: now,
}
return
}

token := val.(*token)
sleepTime := time.Second - time.Since(token.lastUse.Load().(time.Time))
sleepTime := int64(time.Second) - (now - item.lastUse)
if sleepTime <= 0 {
token.lastUse.Store(time.Now())
token.rps.Store(0)
} else if token.rps.Load() >= m.maxRPS {
m.sleep.Sleep(sleepTime)
token.lastUse.Store(time.Now())
token.rps.Store(0)
item.lastUse = now
atomic.StoreUint32(&item.rps, 1)
} else if atomic.LoadUint32(&item.rps) >= m.maxRPS {
m.sleep.Sleep(time.Duration(sleepTime))
item.lastUse = time.Now().UnixNano()
atomic.StoreUint32(&item.rps, 1)
} else {
atomic.AddUint32(&item.rps, 1)
}
token.rps.Add(1)
}

func destructBasket(m *TokensBucket) {
m.cancel.Store(true)
func (m *TokensBucket) getShard(id string) *tokenShard {
hash := fnv.New32a()
_, _ = hash.Write([]byte(id))
return m.shards[hash.Sum32()%m.shardCount]
}

func (m *TokensBucket) deleteUnusedToken() {
for {
if m.cancel.Load() {
return
}
func (m *TokensBucket) cleanupRoutine() {
ticker := time.NewTicker(m.checkTokenTime)
defer ticker.Stop()

m.tokens.Range(func(key, value any) bool {
id, token := key.(string), value.(*token)
if time.Since(token.lastUse.Load().(time.Time)) >= m.unusedTokenTime {
m.tokens.Delete(id)
for {
select {
case <-ticker.C:
if m.cancel.Load() {
return
}
return false
})

m.sleep.Sleep(m.checkTokenTime)
now := time.Now().UnixNano()
for _, shard := range m.shards {
shard.mu.Lock()
for id, token := range shard.tokens {
if now-token.lastUse >= m.unusedTokenTime {
delete(shard.tokens, id)
}
}
shard.mu.Unlock()
}
}
}
}

// destructBucket is registered as a finalizer by NewTokensBucket and
// signals cleanupRoutine to stop.
// NOTE(review): the cleanup goroutine holds a reference to m, which may
// keep it reachable so this finalizer never runs — verify.
func destructBucket(m *TokensBucket) {
m.cancel.Store(true)
}

// noopLimiter is the Limiter behind NoopLimiter; it never throttles.
type noopLimiter struct{}

// Obtain is a no-op: every request is allowed immediately.
func (l *noopLimiter) Obtain(string) {}

// sleeper sleeps. This thing is necessary for tests, which substitute a
// fake implementation instead of blocking on real time.
type sleeper interface {
Sleep(time.Duration)
}
Expand Down
25 changes: 18 additions & 7 deletions v1/rate_limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,22 @@ func (t *TokensBucketTest) Test_NewTokensBucket() {

// new builds a TokensBucket with an injectable sleeper, mirroring
// NewTokensBucket's shard setup so tests control time-based behavior.
func (t *TokensBucketTest) new(
	maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration, sleeper sleeper) *TokensBucket {
	// Use the production constant rather than a duplicated magic number.
	shardCount := uint32(runtime.NumCPU() * shardsPerCoreMultiplier)
	shards := make([]*tokenShard, shardCount)
	for i := range shards {
		shards[i] = &tokenShard{tokens: make(map[string]*token)}
	}

	bucket := &TokensBucket{
		maxRPS:          maxRPS,
		unusedTokenTime: unusedTokenTime.Nanoseconds(),
		checkTokenTime:  checkTokenTime,
		shards:          shards,
		shardCount:      shardCount,
		sleep:           sleeper,
	}

	runtime.SetFinalizer(bucket, destructBucket)
	return bucket
}

Expand All @@ -46,12 +55,14 @@ func (t *TokensBucketTest) Test_Obtain_NoThrottle() {
func (t *TokensBucketTest) Test_Obtain_Sleep() {
clock := &fakeSleeper{}
tb := t.new(100, time.Hour, time.Minute, clock)
_, exists := tb.getShard("w").tokens["w"]
t.Require().False(exists)

var wg sync.WaitGroup
wg.Add(1)
go func() {
for i := 0; i < 301; i++ {
tb.Obtain("a")
tb.Obtain("w")
}
wg.Done()
}()
Expand All @@ -63,15 +74,15 @@ func (t *TokensBucketTest) Test_Obtain_Sleep() {
// Test_Obtain_AddRPS verifies the per-ID counter is created on first
// Obtain and incremented on subsequent hits within the window.
func (t *TokensBucketTest) Test_Obtain_AddRPS() {
	clock := clockwork.NewFakeClock()
	tb := t.new(100, time.Hour, time.Minute, clock)
	go tb.cleanupRoutine()
	tb.Obtain("a")
	clock.Advance(time.Minute * 2)

	item, found := tb.getShard("a").tokens["a"]
	t.Require().True(found)
	t.Assert().Equal(1, int(item.rps))
	tb.Obtain("a")
	t.Assert().Equal(2, int(item.rps))
}

type fakeSleeper struct {
Expand Down
Loading