diff --git a/go.mod b/go.mod index d2950316f..31972f936 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/cloudwego/netpoll v0.6.4 github.com/fsnotify/fsnotify v1.5.4 github.com/nyaruka/phonenumbers v1.0.55 + github.com/stretchr/testify v1.8.1 github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.24.0 @@ -19,14 +20,18 @@ require ( github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.0 // indirect github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect github.com/smartystreets/goconvey v1.6.4 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 7e920d459..fe72a680c 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,7 @@ github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIK github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -78,6 +79,7 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/protocol/client/client.go b/pkg/protocol/client/client.go index 777f55cdd..a5b2c88e8 100644 --- a/pkg/protocol/client/client.go +++ b/pkg/protocol/client/client.go @@ -242,6 +242,9 @@ func DoRequestFollowRedirects(ctx context.Context, req *protocol.Request, resp * break } url = getRedirectURL(url, location) + + // Remove the former host header. + req.Header.Del(consts.HeaderHost) } return statusCode, body, err diff --git a/pkg/protocol/client/client_test.go b/pkg/protocol/client/client_test.go new file mode 100644 index 000000000..49e7e7df6 --- /dev/null +++ b/pkg/protocol/client/client_test.go @@ -0,0 +1,73 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package client + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/hertz/internal/bytestr" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var firstTime = true + +type MockDoer struct { + mock.Mock +} + +func (m *MockDoer) Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { + + // this is the real logic in (c *HostClient) doNonNilReqResp method + if len(req.Header.Host()) == 0 { + req.Header.SetHostBytes(req.URI().Host()) + } + + if firstTime { + // req.Header.Host() is the real host writing to the wire + if string(req.Header.Host()) != "example.com" { + return errors.New("host not match") + } + // this is the real logic in (c *HostClient) doNonNilReqResp method + if len(req.Header.Host()) == 0 { + req.Header.SetHostBytes(req.URI().Host()) + } + resp.Header.SetCanonical(bytestr.StrLocation, []byte("https://a.b.c/foo")) + resp.SetStatusCode(301) + firstTime = false + return nil + } + + if string(req.Header.Host()) != "a.b.c" { + resp.SetStatusCode(400) + return errors.New("host not match") + } + + resp.SetStatusCode(200) + + return nil +} + +func TestDoRequestFollowRedirects(t *testing.T) { + mockDoer := new(MockDoer) + mockDoer.On("Do", mock.Anything, mock.Anything, mock.Anything).Return(nil) + statusCode, _, err := DoRequestFollowRedirects(context.Background(), &protocol.Request{}, &protocol.Response{}, "https://example.com", defaultMaxRedirectsCount, mockDoer) + assert.NoError(t, err) + assert.Equal(t, 200, statusCode) +}