diff --git a/pkg/protocol/http1/req/request.go b/pkg/protocol/http1/req/request.go index f0107476e..c6c5862ad 100644 --- a/pkg/protocol/http1/req/request.go +++ b/pkg/protocol/http1/req/request.go @@ -171,7 +171,7 @@ func write(req *protocol.Request, w network.Writer, usingProxy bool) error { return errRequestHostRequired } - if len(req.Header.Host()) == 0 { + if len(req.Header.Host()) == 0 || req.UseURIHost { req.Header.SetHostBytes(host) } diff --git a/pkg/protocol/http1/req/request_test.go b/pkg/protocol/http1/req/request_test.go index 0411187a5..0506d55c2 100644 --- a/pkg/protocol/http1/req/request_test.go +++ b/pkg/protocol/http1/req/request_test.go @@ -1487,3 +1487,50 @@ func testRequestBodyStreamWithTrailer(t *testing.T, body []byte, disableNormaliz } } } + +func TestURIHostPriority(t *testing.T) { + t.Parallel() + + // normal case + var req protocol.Request + req.Header.SetHost("foobar.com") + req.SetRequestURI("http://foobarhost.com") + req.ParseURI() + var w bytes.Buffer + zw := netpoll.NewWriter(&w) + if err := Write(&req, zw); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := zw.Flush(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var req1 protocol.Request + zr := mock.NewZeroCopyReader(w.String()) + if err := Read(&req1, zr); err != nil { + t.Fatalf("unexpected error: %s", err) + } + assert.DeepEqual(t, "foobar.com", string(req1.Host())) + + // uri higher priority case + var reqURIHighPriority protocol.Request + reqURIHighPriority.Header.SetHost("foobar.com") + reqURIHighPriority.SetRequestURI("http://foobarhost.com") + reqURIHighPriority.ParseURI() + reqURIHighPriority.UseURIHost = true + var bw bytes.Buffer + zw = netpoll.NewWriter(&bw) + if err := Write(&reqURIHighPriority, zw); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := zw.Flush(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var req1URIHighPriority protocol.Request + zr = mock.NewZeroCopyReader(bw.String()) + if err := Read(&req1URIHighPriority, zr); err != nil { + t.Fatalf("unexpected error: %s", err) + } + assert.DeepEqual(t, "foobarhost.com", string(req1URIHighPriority.Host())) +} diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go index 8e4b40bf7..53fe2c99a 100644 --- a/pkg/protocol/request.go +++ b/pkg/protocol/request.go @@ -109,6 +109,9 @@ type Request struct { multipartFiles []*File multipartFields []*MultipartField + // UseURIHost uses URI host as host header. Ignore origin host header + UseURIHost bool + // Request level options, service discovery options etc. options *config.RequestOptions } @@ -190,6 +193,7 @@ func (req *Request) resetSkipHeaderAndConn() { req.parsedURI = false req.parsedPostArgs = false req.postArgs.Reset() + req.UseURIHost = false } func (req *Request) ResetSkipHeader() {