diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index c9332a2ec..c3fc66300 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -57,6 +57,8 @@ import ( "path/filepath" "reflect" "regexp" + "runtime" + "runtime/debug" "strconv" "strings" "sync" @@ -277,6 +279,83 @@ func TestClientPostBodyStream(t *testing.T) { assert.DeepEqual(t, "a="+v, string(body)) } +func mb(byteSize uint64) float32 { + return float32(byteSize) / float32(1024*1024) +} + +func TestBodystreamReleaseMem(t *testing.T) { + debug.SetGCPercent(-1) + defer debug.SetGCPercent(100) + + var ms runtime.MemStats + runtime.ReadMemStats(&ms) + + // start reverse proxy backend engine + rpopt := config.NewOptions([]config.Option{}) + rpopt.Addr = nextUnixSock() + rpopt.Network = "unix" + reverseProxyEngine := route.NewEngine(rpopt) + reverseProxyEngine.POST("/", func(c context.Context, ctx *app.RequestContext) { + body := make([]byte, 1024*1024*1024) + ctx.Write(body) //nolint:errcheck + }) + go reverseProxyEngine.Run() + defer func() { + reverseProxyEngine.Close() + }() + // start engine + opt := config.NewOptions([]config.Option{}) + opt.Addr = nextUnixSock() + opt.Network = "unix" + engine := route.NewEngine(opt) + reverseProxyClient, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(rpopt.Network, rpopt.Addr, 1*time.Second, nil)), WithResponseBodyStream(true)) + engine.POST("/", func(c context.Context, ctx *app.RequestContext) { + req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() + req.SetRequestURI("http://example.com") + req.SetMethod("POST") + err := reverseProxyClient.Do(c, req, resp) + if err != nil { + t.Fatal(err) + } + ctx.SetBodyStream(resp.BodyStream(), -1) + }) + go engine.Run() + defer func() { + engine.Close() + }() + + time.Sleep(time.Second) + req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() + req.SetRequestURI("http://example.com") + req.SetMethod("POST") + cli, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil)), WithResponseBodyStream(true)) + + runtime.ReadMemStats(&ms) + preHeapAlloc, preHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects + t.Logf("After init env, allocation: %f Mb, Number of allocation: %d\n", preHeapAlloc, preHeapObjects) + + err := cli.Do(context.Background(), req, resp) + if err != nil { + t.Fatal(err) + } + if fn, ok := resp.BodyStream().(io.Closer); ok { + fn.Close() + } + + // Trigger the finalizer of kclient be executed + time.Sleep(200 * time.Millisecond) // ensure the finalizer be executed + runtime.GC() + runtime.GC() + runtime.ReadMemStats(&ms) + secondGCHeapAlloc, secondGCHeapObjects := mb(ms.HeapAlloc), ms.HeapObjects + t.Logf("After sending request, allocation: %f Mb, Number of allocation: %d\n", secondGCHeapAlloc, secondGCHeapObjects) + runtime.GC() + if secondGCHeapAlloc/2 > preHeapAlloc { + // using t.Fatalf will cause memory not recycled. So we use panic here + panic(fmt.Sprintf("memory leak, preHeapAlloc: %f, secondGCHeapAlloc: %f", preHeapAlloc, secondGCHeapAlloc)) + } +} + func TestClientURLAuth(t *testing.T) { cases := map[string]string{ "foo:bar@": "Basic Zm9vOmJhcg==", diff --git a/pkg/common/test/mock/network.go b/pkg/common/test/mock/network.go index fe671ae56..5b67142d8 100644 --- a/pkg/common/test/mock/network.go +++ b/pkg/common/test/mock/network.go @@ -325,7 +325,8 @@ func (m *Conn) AddCloseCallback(callback netpoll.CloseCallback) error { } type StreamConn struct { - Data []byte + Data []byte + HasReleased bool } func NewStreamConn() *StreamConn { @@ -354,7 +355,8 @@ func (m *StreamConn) Skip(n int) error { } func (m *StreamConn) Release() error { - panic("implement me") + m.HasReleased = true + return nil } func (m *StreamConn) Len() int { diff --git a/pkg/protocol/http1/ext/stream.go b/pkg/protocol/http1/ext/stream.go index ae81b560e..51bb01d43 100644 --- a/pkg/protocol/http1/ext/stream.go +++ b/pkg/protocol/http1/ext/stream.go @@ -272,6 +272,11 @@ func (rs *bodyStream) skipRest() error { if err != nil { return err } + err = rs.reader.Release() + if err != nil { + return err + } + } } // max value of pSize is 8193, it's safe. @@ -300,7 +305,14 @@ func (rs *bodyStream) skipRest() error { if skip > needSkipLen { skip = needSkipLen } - rs.reader.Skip(skip) + err := rs.reader.Skip(skip) + if err != nil { + return err + } + err = rs.reader.Release() + if err != nil { + return err + } needSkipLen -= skip if needSkipLen == 0 { return nil diff --git a/pkg/protocol/http1/req/request_test.go b/pkg/protocol/http1/req/request_test.go index 0411187a5..e9451e776 100644 --- a/pkg/protocol/http1/req/request_test.go +++ b/pkg/protocol/http1/req/request_test.go @@ -1425,6 +1425,7 @@ func TestStreamNotEnoughData(t *testing.T) { err = ext.ReleaseBodyStream(req.BodyStream()) assert.Nil(t, err) assert.DeepEqual(t, 0, len(conn.Data)) + assert.DeepEqual(t, true, conn.HasReleased) } func TestRequestBodyStreamWithTrailer(t *testing.T) {