diff --git a/pkg/server/service/loadbalancer/wrr/wrr.go b/pkg/server/service/loadbalancer/wrr/wrr.go index 720945ffa3..009835fbf6 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr.go +++ b/pkg/server/service/loadbalancer/wrr/wrr.go @@ -4,12 +4,11 @@ import ( "container/heap" "context" "errors" - "hash/fnv" "net/http" - "strconv" "sync" "github.com/traefik/traefik/v2/pkg/config/dynamic" + "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/log" ) @@ -17,7 +16,9 @@ type namedHandler struct { http.Handler name string weight float64 - deadline float64 + pending uint64 + healthy bool + queueIdx int } type stickyCookie struct { @@ -34,27 +35,20 @@ type stickyCookie struct { type Balancer struct { stickyCookie *stickyCookie wantsHealthCheck bool - - handlersMu sync.RWMutex - // References all the handlers by name and also by the hashed value of the name. - handlerMap map[string]*namedHandler - handlers []*namedHandler - curDeadline float64 - // status is a record of which child services of the Balancer are healthy, keyed - // by name of child service. A service is initially added to the map when it is - // created via Add, and it is later removed or added to the map as needed, - // through the SetStatus method. - status map[string]struct{} // updaters is the list of hooks that are run (to update the Balancer // parent(s)), whenever the Balancer status changes. updaters []func(bool) + + mutex sync.RWMutex + enabledHandlers priorityQueue + handlersByName map[string]*namedHandler + healthyCount int } // New creates a new load balancer. func New(sticky *dynamic.Sticky, wantHealthCheck bool) *Balancer { balancer := &Balancer{ - status: make(map[string]struct{}), - handlerMap: make(map[string]*namedHandler), + handlersByName: make(map[string]*namedHandler), wantsHealthCheck: wantHealthCheck, } if sticky != nil && sticky.Cookie != nil { @@ -64,78 +58,53 @@ func New(sticky *dynamic.Sticky, wantHealthCheck bool) *Balancer { httpOnly: sticky.Cookie.HTTPOnly, } } - return balancer } -// Len implements heap.Interface/sort.Interface. -func (b *Balancer) Len() int { return len(b.handlers) } - -// Less implements heap.Interface/sort.Interface. -func (b *Balancer) Less(i, j int) bool { - return b.handlers[i].deadline < b.handlers[j].deadline -} - -// Swap implements heap.Interface/sort.Interface. -func (b *Balancer) Swap(i, j int) { - b.handlers[i], b.handlers[j] = b.handlers[j], b.handlers[i] -} - -// Push implements heap.Interface for pushing an item into the heap. -func (b *Balancer) Push(x interface{}) { - h, ok := x.(*namedHandler) - if !ok { - return - } - - b.handlers = append(b.handlers, h) -} - -// Pop implements heap.Interface for popping an item from the heap. -// It panics if b.Len() < 1. -func (b *Balancer) Pop() interface{} { - h := b.handlers[len(b.handlers)-1] - b.handlers = b.handlers[0 : len(b.handlers)-1] - return h -} - // SetStatus sets on the balancer that its given child is now of the given -// status. balancerName is only needed for logging purposes. -func (b *Balancer) SetStatus(ctx context.Context, childName string, up bool) { - b.handlersMu.Lock() - defer b.handlersMu.Unlock() - - upBefore := len(b.status) > 0 - - status := "DOWN" - if up { - status = "UP" - } - log.FromContext(ctx).Debugf("Setting status of %s to %v", childName, status) - if up { - b.status[childName] = struct{}{} - } else { - delete(b.status, childName) +// status. +func (b *Balancer) SetStatus(ctx context.Context, childName string, healthy bool) { + log.FromContext(ctx).Debugf("Setting status of %s to %v", childName, statusAsStr(healthy)) + + b.mutex.Lock() + nh := b.handlersByName[childName] + if nh == nil { + b.mutex.Unlock() + return } - upAfter := len(b.status) > 0 - status = "DOWN" - if upAfter { - status = "UP" + healthyBefore := b.healthyCount > 0 + if nh.healthy != healthy { + nh.healthy = healthy + if healthy { + b.healthyCount++ + b.enabledHandlers.push(nh) + } else { + b.healthyCount-- + } } + healthyAfter := b.healthyCount > 0 + b.mutex.Unlock() // No Status Change - if upBefore == upAfter { + if healthyBefore == healthyAfter { // We're still with the same status, no need to propagate - log.FromContext(ctx).Debugf("Still %s, no need to propagate", status) + log.FromContext(ctx).Debugf("Still %s, no need to propagate", statusAsStr(healthyBefore)) return } // Status Change - log.FromContext(ctx).Debugf("Propagating new %s status", status) + log.FromContext(ctx).Debugf("Propagating new %s status", statusAsStr(healthyAfter)) for _, fn := range b.updaters { - fn(upAfter) + fn(healthyAfter) + } +} + +func statusAsStr(healthy bool) string { + if healthy { + return runtime.StatusUp } + return runtime.StatusDown } // RegisterStatusUpdater adds fn to the list of hooks that are run when the @@ -151,59 +120,61 @@ func (b *Balancer) RegisterStatusUpdater(fn func(up bool)) error { var errNoAvailableServer = errors.New("no available server") -func (b *Balancer) nextServer() (*namedHandler, error) { - b.handlersMu.Lock() - defer b.handlersMu.Unlock() - - if len(b.handlers) == 0 || len(b.status) == 0 { - return nil, errNoAvailableServer +func (b *Balancer) acquireHandler(preferredName string) (*namedHandler, error) { + b.mutex.Lock() + defer b.mutex.Unlock() + var nh *namedHandler + // Check the preferred handler fist if provided. + if preferredName != "" { + nh = b.handlersByName[preferredName] + if nh != nil && nh.healthy { + nh.pending++ + b.enabledHandlers.fix(nh) + return nh, nil + } } - - var handler *namedHandler + // Pick the handler with the least number of pending requests. for { - // Pick handler with closest deadline. - handler = heap.Pop(b).(*namedHandler) - - // curDeadline should be handler's deadline so that new added entry would have a fair competition environment with the old ones. - b.curDeadline = handler.deadline - handler.deadline += 1 / handler.weight - - heap.Push(b, handler) - if _, ok := b.status[handler.name]; ok { - break + nh = b.enabledHandlers.pop() + if nh == nil { + return nil, errNoAvailableServer + } + // If the handler is marked as unhealthy, then continue with the next + // best option. It will be put back into the priority queue once its + // status changes to healthy. + if !nh.healthy { + continue } + // Otherwise increment the number of pending requests, put it back into + // the priority queue, and return it as a selected for the request. + nh.pending++ + b.enabledHandlers.push(nh) + log.WithoutContext().Debugf("Service selected by WRR: %s", nh.name) + return nh, nil } +} - log.WithoutContext().Debugf("Service selected by WRR: %s", handler.name) - return handler, nil +func (b *Balancer) releaseHandler(nh *namedHandler) { + b.mutex.Lock() + defer b.mutex.Unlock() + nh.pending-- + if nh.healthy { + b.enabledHandlers.fix(nh) + } } func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + var preferredName string if b.stickyCookie != nil { cookie, err := req.Cookie(b.stickyCookie.name) - if err != nil && !errors.Is(err, http.ErrNoCookie) { log.WithoutContext().Warnf("Error while reading cookie: %v", err) } - if err == nil && cookie != nil { - b.handlersMu.RLock() - handler, ok := b.handlerMap[cookie.Value] - b.handlersMu.RUnlock() - - if ok && handler != nil { - b.handlersMu.RLock() - _, isHealthy := b.status[handler.name] - b.handlersMu.RUnlock() - if isHealthy { - handler.ServeHTTP(w, req) - return - } - } + preferredName = cookie.Value } } - - server, err := b.nextServer() + nh, err := b.acquireHandler(preferredName) if err != nil { if errors.Is(err, errNoAvailableServer) { http.Error(w, errNoAvailableServer.Error(), http.StatusServiceUnavailable) @@ -214,11 +185,18 @@ func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if b.stickyCookie != nil { - cookie := &http.Cookie{Name: b.stickyCookie.name, Value: hash(server.name), Path: "/", HttpOnly: b.stickyCookie.httpOnly, Secure: b.stickyCookie.secure} + cookie := &http.Cookie{ + Name: b.stickyCookie.name, + Value: nh.name, + Path: "/", + HttpOnly: b.stickyCookie.httpOnly, + Secure: b.stickyCookie.secure, + } http.SetCookie(w, cookie) } - server.ServeHTTP(w, req) + nh.ServeHTTP(w, req) + b.releaseHandler(nh) } // Add adds a handler. @@ -233,21 +211,67 @@ func (b *Balancer) Add(name string, handler http.Handler, weight *int) { return } - h := &namedHandler{Handler: handler, name: name, weight: float64(w)} + nh := &namedHandler{ + Handler: handler, + name: name, + weight: float64(w), + pending: 1, + healthy: true, + } + b.mutex.Lock() + b.enabledHandlers.push(nh) + b.handlersByName[nh.name] = nh + b.healthyCount++ + b.mutex.Unlock() +} + +type priorityQueue struct { + heap []*namedHandler +} - b.handlersMu.Lock() - h.deadline = b.curDeadline + 1/h.weight - heap.Push(b, h) - b.status[name] = struct{}{} - b.handlerMap[name] = h - b.handlerMap[hash(name)] = h - b.handlersMu.Unlock() +func (pq *priorityQueue) push(nh *namedHandler) { + heap.Push(pq, nh) } -func hash(input string) string { - hasher := fnv.New64() - // We purposely ignore the error because the implementation always returns nil. - _, _ = hasher.Write([]byte(input)) +func (pq *priorityQueue) pop() *namedHandler { + if len(pq.heap) < 1 { + return nil + } + return heap.Pop(pq).(*namedHandler) +} - return strconv.FormatUint(hasher.Sum64(), 16) +func (pq *priorityQueue) fix(nh *namedHandler) { + heap.Fix(pq, nh.queueIdx) +} + +// Len implements heap.Interface/sort.Interface. +func (pq *priorityQueue) Len() int { return len(pq.heap) } + +// Less implements heap.Interface/sort.Interface. +func (pq *priorityQueue) Less(i, j int) bool { + nhi, nhj := pq.heap[i], pq.heap[j] + return float64(nhi.pending)/nhi.weight < float64(nhj.pending)/nhj.weight +} + +// Swap implements heap.Interface/sort.Interface. +func (pq *priorityQueue) Swap(i, j int) { + pq.heap[i], pq.heap[j] = pq.heap[j], pq.heap[i] + pq.heap[i].queueIdx = i + pq.heap[j].queueIdx = j +} + +// Push implements heap.Interface for pushing an item into the heap. +func (pq *priorityQueue) Push(x interface{}) { + nh := x.(*namedHandler) + nh.queueIdx = len(pq.heap) + pq.heap = append(pq.heap, nh) +} + +// Pop implements heap.Interface for popping an item from the heap. +// It panics if b.Len() < 1. +func (pq *priorityQueue) Pop() interface{} { + lastIdx := len(pq.heap) - 1 + nh := pq.heap[lastIdx] + pq.heap = pq.heap[0:lastIdx] + return nh } diff --git a/pkg/server/service/loadbalancer/wrr/wrr_test.go b/pkg/server/service/loadbalancer/wrr/wrr_test.go index 19f2cf38ef..e08fc9a3a3 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr_test.go +++ b/pkg/server/service/loadbalancer/wrr/wrr_test.go @@ -2,322 +2,205 @@ package wrr import ( "context" + "fmt" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" - "github.com/traefik/traefik/v2/pkg/config/dynamic" + "github.com/stretchr/testify/require" ) -func TestBalancer(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(3)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 4 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - } +const ( + handlerAny = "" +) - assert.Equal(t, 3, recorder.save["first"]) - assert.Equal(t, 1, recorder.save["second"]) +func TestBalancerWeights(t *testing.T) { + b := New(nil, false) + addDummyHandler(b, "A", 3) + addDummyHandler(b, "B", 1) + + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 1, "B": 0}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 2, "B": 0}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 2, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 3, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 4, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 5, "B": 1}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 2}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 6, "B": 2}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 7, "B": 2}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 8, "B": 2}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 8, "B": 3}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 9, "B": 3}) + assertRelease(t, b, "B", map[string]int{"A": 9, "B": 2}) + assertRelease(t, b, "B", map[string]int{"A": 9, "B": 1}) + assertRelease(t, b, "B", map[string]int{"A": 9, "B": 0}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 9, "B": 1}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 9, "B": 2}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 9, "B": 3}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 10, "B": 3}) } -func TestBalancerNoService(t *testing.T) { - balancer := New(nil, false) - - recorder := httptest.NewRecorder() - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - - assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) +func TestBalancerUpAndDown(t *testing.T) { + b := New(nil, false) + addDummyHandler(b, "A", 1) + addDummyHandler(b, "B", 1) + + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 1, "B": 0}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 1, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 2, "B": 1}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 2, "B": 2}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 3, "B": 2}) + b.SetStatus(context.Background(), "B", false) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 4, "B": 2}) + b.SetStatus(context.Background(), "B", false) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 5, "B": 2}) + b.SetStatus(context.Background(), "A", false) + _, err := b.acquireHandler(handlerAny) + assert.Equal(t, errNoAvailableServer, err) + assertRelease(t, b, "B", map[string]int{"A": 5, "B": 1}) + assertRelease(t, b, "A", map[string]int{"A": 4, "B": 1}) + assertRelease(t, b, "A", map[string]int{"A": 3, "B": 1}) + _, err = b.acquireHandler(handlerAny) + assert.Equal(t, errNoAvailableServer, err) + b.SetStatus(context.Background(), "A", true) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 4, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 5, "B": 1}) + b.SetStatus(context.Background(), "B", true) + b.SetStatus(context.Background(), "B", true) + b.SetStatus(context.Background(), "A", true) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 2}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 3}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 4}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 5, "B": 5}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 6, "B": 5}) } -func TestBalancerOneServerZeroWeight(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) +func TestBalancerZeroWeight(t *testing.T) { + b := New(nil, false) + addDummyHandler(b, "A", 0) + addDummyHandler(b, "B", 1) - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 3 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - } - - assert.Equal(t, 3, recorder.save["first"]) + assertAcquire(t, b, handlerAny, "B", map[string]int{"B": 1}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"B": 2}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"B": 3}) } -type key string - -const serviceName key = "serviceName" - -func TestBalancerNoServiceUp(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusInternalServerError) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusInternalServerError) - }), Int(1)) - - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "first", false) - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", false) - - recorder := httptest.NewRecorder() - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - - assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) +func TestBalancerPropagate(t *testing.T) { + b := New(nil, true) + addDummyHandler(b, "A", 1) + addDummyHandler(b, "B", 1) + updates := []bool{} + err := b.RegisterStatusUpdater(func(healthy bool) { + updates = append(updates, healthy) + }) + require.NoError(t, err) + + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{}, updates) + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{}, updates) + b.SetStatus(context.Background(), "B", false) + assert.Equal(t, []bool{false}, updates) + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{false}, updates) + b.SetStatus(context.Background(), "B", false) + assert.Equal(t, []bool{false}, updates) + b.SetStatus(context.Background(), "B", true) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "B", true) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "A", true) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "A", false) + assert.Equal(t, []bool{false, true}, updates) + b.SetStatus(context.Background(), "B", false) + assert.Equal(t, []bool{false, true, false}, updates) } -func TestBalancerOneServerDown(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(http.StatusInternalServerError) - }), Int(1)) - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", false) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 3 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - } - - assert.Equal(t, 3, recorder.save["first"]) +func TestBalancerSticky(t *testing.T) { + b := New(nil, false) + addDummyHandler(b, "A", 1) + addDummyHandler(b, "B", 1) + + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 1, "B": 0}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 1, "B": 1}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 2, "B": 1}) + assertAcquire(t, b, "A", "A", map[string]int{"A": 3, "B": 1}) + assertAcquire(t, b, "A", "A", map[string]int{"A": 4, "B": 1}) + assertAcquire(t, b, "A", "A", map[string]int{"A": 5, "B": 1}) + b.SetStatus(context.Background(), "A", false) + // Even though A is preferred B is allocated when A is not available. + assertAcquire(t, b, "A", "B", map[string]int{"A": 5, "B": 2}) + assertAcquire(t, b, "A", "B", map[string]int{"A": 5, "B": 3}) + b.SetStatus(context.Background(), "A", true) + assertAcquire(t, b, "A", "A", map[string]int{"A": 6, "B": 3}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 6, "B": 4}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 6, "B": 5}) + assertAcquire(t, b, handlerAny, "B", map[string]int{"A": 6, "B": 6}) + assertAcquire(t, b, handlerAny, "A", map[string]int{"A": 7, "B": 6}) } -func TestBalancerDownThenUp(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", false) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 3 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - } - assert.Equal(t, 3, recorder.save["first"]) - - balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", true) - recorder = &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 2 { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) +// When sticky sessions are allocated that does not mess up selection order. +// Internally heap is used and sticky allocation has to maintain correct +// ordering of handlers in the priority queue. +func TestBalancerMany(t *testing.T) { + b := New(nil, false) + for _, handlerName := range "ABCDEFGH" { + addDummyHandler(b, fmt.Sprintf("%c", handlerName), 1) } - assert.Equal(t, 1, recorder.save["first"]) - assert.Equal(t, 1, recorder.save["second"]) -} - -func TestBalancerPropagate(t *testing.T) { - balancer1 := New(nil, true) - - balancer1.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - balancer1.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer2 := New(nil, true) - balancer2.Add("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "third") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - balancer2.Add("fourth", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "fourth") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - topBalancer := New(nil, true) - topBalancer.Add("balancer1", balancer1, Int(1)) - _ = balancer1.RegisterStatusUpdater(func(up bool) { - topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer1", up) - // TODO(mpl): if test gets flaky, add channel or something here to signal that - // propagation is done, and wait on it before sending request. - }) - topBalancer.Add("balancer2", balancer2, Int(1)) - _ = balancer2.RegisterStatusUpdater(func(up bool) { - topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer2", up) - }) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 8 { - topBalancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + for i := 0; i < 100; i++ { + _, err := b.acquireHandler(handlerAny) + require.NoError(t, err) } - assert.Equal(t, 2, recorder.save["first"]) - assert.Equal(t, 2, recorder.save["second"]) - assert.Equal(t, 2, recorder.save["third"]) - assert.Equal(t, 2, recorder.save["fourth"]) - wantStatus := []int{200, 200, 200, 200, 200, 200, 200, 200} - assert.Equal(t, wantStatus, recorder.status) - - // fourth gets downed, but balancer2 still up since third is still up. - balancer2.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "fourth", false) - recorder = &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 8 { - topBalancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + assert.Equal(t, map[string]int{"A": 13, "B": 13, "C": 12, "D": 13, "E": 12, "F": 12, "G": 12, "H": 13}, pendingCounts(b)) + for i := 0; i < 10; i++ { + _, err := b.acquireHandler("D") + require.NoError(t, err) } - assert.Equal(t, 2, recorder.save["first"]) - assert.Equal(t, 2, recorder.save["second"]) - assert.Equal(t, 4, recorder.save["third"]) - assert.Equal(t, 0, recorder.save["fourth"]) - wantStatus = []int{200, 200, 200, 200, 200, 200, 200, 200} - assert.Equal(t, wantStatus, recorder.status) - - // third gets downed, and the propagation triggers balancer2 to be marked as - // down as well for topBalancer. - balancer2.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "third", false) - recorder = &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - for range 8 { - topBalancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + assert.Equal(t, map[string]int{"A": 13, "B": 13, "C": 12, "D": 23, "E": 12, "F": 12, "G": 12, "H": 13}, pendingCounts(b)) + for i := 0; i < 74; i++ { + _, err := b.acquireHandler(handlerAny) + require.NoError(t, err) } - assert.Equal(t, 4, recorder.save["first"]) - assert.Equal(t, 4, recorder.save["second"]) - assert.Equal(t, 0, recorder.save["third"]) - assert.Equal(t, 0, recorder.save["fourth"]) - wantStatus = []int{200, 200, 200, 200, 200, 200, 200, 200} - assert.Equal(t, wantStatus, recorder.status) -} - -func TestBalancerAllServersZeroWeight(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) - balancer.Add("test2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0)) - - recorder := httptest.NewRecorder() - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) - - assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode) -} - -func TestSticky(t *testing.T) { - balancer := New(&dynamic.Sticky{ - Cookie: &dynamic.Cookie{Name: "test"}, - }, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") - rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(2)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - - req := httptest.NewRequest(http.MethodGet, "/", nil) - for range 3 { - for _, cookie := range recorder.Result().Cookies() { - assert.NotContains(t, "test=first", cookie.Value) - assert.NotContains(t, "test=second", cookie.Value) - req.AddCookie(cookie) - } - recorder.ResponseRecorder = httptest.NewRecorder() - - balancer.ServeHTTP(recorder, req) + assert.Equal(t, map[string]int{"A": 23, "B": 23, "C": 23, "D": 23, "E": 23, "F": 23, "G": 23, "H": 23}, pendingCounts(b)) + for i := 0; i < 8; i++ { + _, err := b.acquireHandler(handlerAny) + require.NoError(t, err) } - - assert.Equal(t, 0, recorder.save["first"]) - assert.Equal(t, 3, recorder.save["second"]) + assert.Equal(t, map[string]int{"A": 24, "B": 24, "C": 24, "D": 24, "E": 24, "F": 24, "G": 24, "H": 24}, pendingCounts(b)) } -func TestSticky_FallBack(t *testing.T) { - balancer := New(&dynamic.Sticky{ - Cookie: &dynamic.Cookie{Name: "test"}, - }, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "first") +func addDummyHandler(b *Balancer, handlerName string, weight int) { + h := func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("server", handlerName) rw.WriteHeader(http.StatusOK) - }), Int(1)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "second") - rw.WriteHeader(http.StatusOK) - }), Int(2)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.AddCookie(&http.Cookie{Name: "test", Value: "second"}) - for range 3 { - recorder.ResponseRecorder = httptest.NewRecorder() - - balancer.ServeHTTP(recorder, req) } - - assert.Equal(t, 0, recorder.save["first"]) - assert.Equal(t, 3, recorder.save["second"]) + b.Add(handlerName, http.HandlerFunc(h), &weight) } -// TestBalancerBias makes sure that the WRR algorithm spreads elements evenly right from the start, -// and that it does not "over-favor" the high-weighted ones with a biased start-up regime. -func TestBalancerBias(t *testing.T) { - balancer := New(nil, false) - - balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "A") - rw.WriteHeader(http.StatusOK) - }), Int(11)) - - balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("server", "B") - rw.WriteHeader(http.StatusOK) - }), Int(3)) - - recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} - - for i := 0; i < 14; i++ { - balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) +func pendingCounts(b *Balancer) map[string]int { + countsByName := make(map[string]int) + b.mutex.Lock() + for handlerName, handler := range b.handlersByName { + countsByName[handlerName] = int(handler.pending) - 1 } - - wantSequence := []string{"A", "A", "A", "B", "A", "A", "A", "A", "B", "A", "A", "A", "B", "A"} - - assert.Equal(t, wantSequence, recorder.sequence) + b.mutex.Unlock() + return countsByName } -func Int(v int) *int { return &v } - -type responseRecorder struct { - *httptest.ResponseRecorder - save map[string]int - sequence []string - status []int +func assertAcquire(t *testing.T, b *Balancer, preferredName, acquiredName string, want map[string]int) { + nh, err := b.acquireHandler(preferredName) + require.NoError(t, err) + assert.Equal(t, acquiredName, nh.name) + assert.Equal(t, want, pendingCounts(b)) } -func (r *responseRecorder) WriteHeader(statusCode int) { - r.save[r.Header().Get("server")]++ - r.sequence = append(r.sequence, r.Header().Get("server")) - r.status = append(r.status, statusCode) - r.ResponseRecorder.WriteHeader(statusCode) +func assertRelease(t *testing.T, b *Balancer, acquiredName string, want map[string]int) { + b.mutex.Lock() + nh := b.handlersByName[acquiredName] + b.mutex.Unlock() + b.releaseHandler(nh) + assert.Equal(t, want, pendingCounts(b)) }