Skip to content

Commit

Permalink
chore: optimize recover logic (#133)
Browse files Browse the repository at this point in the history
* chore: optimize recover logic

* add test cases

* optimize
  • Loading branch information
hwbrzzl authored Dec 30, 2024
1 parent 495d8c8 commit 19f4600
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 19 deletions.
7 changes: 1 addition & 6 deletions middleware_timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@ func Timeout(timeout time.Duration) contractshttp.Middleware {
go func() {
defer func() {
if err := recover(); err != nil {
if globalRecoverCallback != nil {
globalRecoverCallback(ctx, err)
} else {
LogFacade.Error(err)
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Server Error"})
}
globalRecoverCallback(ctx, err)
}

close(done)
Expand Down
40 changes: 31 additions & 9 deletions middleware_timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ import (
"github.com/gofiber/fiber/v2"
contractshttp "github.com/goravel/framework/contracts/http"
mocksconfig "github.com/goravel/framework/mocks/config"
mockslog "github.com/goravel/framework/mocks/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

func TestTimeoutMiddleware(t *testing.T) {
mockConfig := mocksconfig.NewConfig(t)
mockConfig.On("GetBool", "http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.On("GetInt", "http.drivers.fiber.body_limit", 4096).Return(4096).Once()
mockConfig.On("GetInt", "http.drivers.fiber.header_limit", 4096).Return(4096).Once()
mockConfig.EXPECT().GetBool("http.drivers.fiber.prefork", false).Return(false).Once()
mockConfig.EXPECT().GetInt("http.drivers.fiber.body_limit", 4096).Return(4096).Once()
mockConfig.EXPECT().GetInt("http.drivers.fiber.header_limit", 4096).Return(4096).Once()

route, err := NewRoute(mockConfig, nil)
require.NoError(t, err)
Expand All @@ -35,11 +37,6 @@ func TestTimeoutMiddleware(t *testing.T) {
panic("test panic")
})

globalRecover := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Panic"})
}
route.Recover(globalRecover)

t.Run("timeout", func(t *testing.T) {
req, err := http.NewRequest("GET", "/timeout", nil)
require.NoError(t, err)
Expand Down Expand Up @@ -69,7 +66,32 @@ func TestTimeoutMiddleware(t *testing.T) {
assert.Equal(t, "normal", string(body))
})

t.Run("panic", func(t *testing.T) {
t.Run("panic with default recover", func(t *testing.T) {
mockLog := mockslog.NewLog(t)
mockLog.EXPECT().WithContext(mock.Anything).Return(mockLog).Once()
mockLog.EXPECT().Request(mock.Anything).Return(mockLog).Once()
mockLog.EXPECT().Error("test panic").Once()
LogFacade = mockLog

req, err := http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)

resp, err := route.instance.Test(req, -1)
require.NoError(t, err)

assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "Internal Server Error", string(body))
})

t.Run("panic with custom recover", func(t *testing.T) {
globalRecover := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, fiber.Map{"error": "Internal Panic"})
}
route.Recover(globalRecover)

req, err := http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)

Expand Down
9 changes: 5 additions & 4 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ import (
"github.com/savioxavier/termlink"
)

var globalRecoverCallback func(ctx httpcontract.Context, err any)
var globalRecoverCallback func(ctx httpcontract.Context, err any) = func(ctx httpcontract.Context, err any) {
LogFacade.WithContext(ctx).Request(ctx.Request()).Error(err)
ctx.Request().AbortWithStatus(http.StatusInternalServerError)
}

// Route fiber route
// Route fiber 路由
Expand Down Expand Up @@ -138,9 +141,7 @@ func (r *Route) Recover(callback func(ctx httpcontract.Context, err any)) {
func(ctx httpcontract.Context) {
defer func() {
if err := recover(); err != nil {
if callback != nil {
callback(ctx, err)
}
callback(ctx, err)
}
}()
ctx.Request().Next()
Expand Down

0 comments on commit 19f4600

Please sign in to comment.