From 9e940c58e4e7660b1151f22bfaae211abd10b04c Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Tue, 30 Nov 2021 16:28:19 +0100 Subject: [PATCH] Add IsHedgedRequest (#19) --- hedged.go | 15 +++++++++++++-- hedged_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/hedged.go b/hedged.go index f3a3f04..17d0663 100644 --- a/hedged.go +++ b/hedged.go @@ -107,7 +107,7 @@ func (ht *hedgedTransport) RoundTrip(req *http.Request) (*http.Response, error) for sent := 0; len(errOverall.Errors) < ht.upto; sent++ { if sent < ht.upto { idx := sent - subReq, cancel := reqWithCtx(req, mainCtx) + subReq, cancel := reqWithCtx(req, mainCtx, idx != 0) cancels[idx] = cancel runInPool(func() { @@ -174,12 +174,23 @@ type indexedResp struct { Resp *http.Response } -func reqWithCtx(r *http.Request, ctx context.Context) (*http.Request, func()) { +func reqWithCtx(r *http.Request, ctx context.Context, isHedged bool) (*http.Request, func()) { ctx, cancel := context.WithCancel(ctx) + if isHedged { + ctx = context.WithValue(ctx, hedgedRequest{}, struct{}{}) + } req := r.WithContext(ctx) return req, cancel } +type hedgedRequest struct{} + +// IsHedgedRequest reports when a request is hedged. +func IsHedgedRequest(r *http.Request) bool { + val := r.Context().Value(hedgedRequest{}) + return val != nil +} + // atomicCounter is a false sharing safe counter. type atomicCounter struct { count uint64 diff --git a/hedged_test.go b/hedged_test.go index d5b15c5..16f5557 100644 --- a/hedged_test.go +++ b/hedged_test.go @@ -3,6 +3,7 @@ package hedgedhttp_test import ( "bytes" "context" + "errors" "fmt" "io" "io/ioutil" @@ -425,6 +426,42 @@ func TestCancelByClient(t *testing.T) { } } +func TestIsHedged(t *testing.T) { + var gotRequests int + + rt := testRoundTripper(func(req *http.Request) (*http.Response, error) { + if gotRequests == 0 { + if hedgedhttp.IsHedgedRequest(req) { + t.Fatal("first request is hedged") + } + } else { + if !hedgedhttp.IsHedgedRequest(req) { + t.Fatalf("%d request is not hedged", gotRequests) + } + } + gotRequests++ + return nil, errors.New("just an error") + }) + + req, err := http.NewRequest("GET", "http://no-matter-what", http.NoBody) + if err != nil { + t.Fatal(err) + } + + const upto = 7 + _, _ = hedgedhttp.NewRoundTripper(10*time.Millisecond, upto, rt).RoundTrip(req) + + if gotRequests != upto { + t.Fatalf("want %v, got %v", upto, gotRequests) + } +} + +type testRoundTripper func(req *http.Request) (*http.Response, error) + +func (t testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t(req) +} + func checkAllMetricsAreZero(t *testing.T, metrics *hedgedhttp.Stats) { expectExactMetricsAndSnapshot(t, metrics, hedgedhttp.StatsSnapshot{}) }