Skip to content

Commit

Permalink
fix: file methods and GlobalMiddleware method (#35)
Browse files Browse the repository at this point in the history
* fix: file methods

* fix: middlewares

* fix: tests

* feat: optimize path

* feat: optimize path tests

* fix: lint

* fix: ???

* fix: url error

* fix: tests

* fix: panic when route undefined

* fix: prefix problem

* feat: optimize Fallback method
  • Loading branch information
devhaozi authored Nov 4, 2023
1 parent 69fe2f0 commit c3c3624
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 86 deletions.
87 changes: 68 additions & 19 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package fiber

import (
"net/http"
"net/url"
"path/filepath"

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/filesystem"
Expand All @@ -15,16 +17,18 @@ type Group struct {
instance *fiber.App
originPrefix string
prefix string
globalMiddlewares []any
originMiddlewares []httpcontract.Middleware
middlewares []httpcontract.Middleware
lastMiddlewares []httpcontract.Middleware
}

func NewGroup(config config.Config, instance *fiber.App, prefix string, originMiddlewares []httpcontract.Middleware, lastMiddlewares []httpcontract.Middleware) route.Router {
func NewGroup(config config.Config, instance *fiber.App, prefix string, globalMiddlewares []any, originMiddlewares []httpcontract.Middleware, lastMiddlewares []httpcontract.Middleware) route.Router {
return &Group{
config: config,
instance: instance,
originPrefix: prefix,
globalMiddlewares: globalMiddlewares,
originMiddlewares: originMiddlewares,
lastMiddlewares: lastMiddlewares,
}
Expand All @@ -38,7 +42,7 @@ func (r *Group) Group(handler route.GroupFunc) {
prefix := pathToFiberPath(r.originPrefix + "/" + r.prefix)
r.prefix = ""

handler(NewGroup(r.config, r.instance, prefix, middlewares, r.lastMiddlewares))
handler(NewGroup(r.config, r.instance, prefix, r.globalMiddlewares, middlewares, r.lastMiddlewares))
}

func (r *Group) Prefix(addr string) route.Router {
Expand All @@ -54,61 +58,79 @@ func (r *Group) Middleware(middlewares ...httpcontract.Middleware) route.Router
}

func (r *Group) Any(relativePath string, handler httpcontract.HandlerFunc) {
r.instance.All(r.getPath(relativePath), r.getMiddlewares(handler)...)
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).All(relativePath, r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Get(relativePath string, handler httpcontract.HandlerFunc) {
r.instance.Get(r.getPath(relativePath), r.getMiddlewares(handler)...)
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Get(relativePath, r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Post(relativePath string, handler httpcontract.HandlerFunc) {
r.instance.Post(r.getPath(relativePath), r.getMiddlewares(handler)...)
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Post(relativePath, r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Delete(relativePath string, handler httpcontract.HandlerFunc) {
r.instance.Delete(r.getPath(relativePath), r.getMiddlewares(handler)...)
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Delete(relativePath, r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Patch(relativePath string, handler httpcontract.HandlerFunc) {
r.instance.Patch(r.getPath(relativePath), r.getMiddlewares(handler)...)
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Patch(relativePath, r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Put(relativePath string, handler httpcontract.HandlerFunc) {
r.instance.Put(r.getPath(relativePath), r.getMiddlewares(handler)...)
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Put(relativePath, r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Options(relativePath string, handler httpcontract.HandlerFunc) {
r.instance.Options(r.getPath(relativePath), r.getMiddlewares(handler)...)
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Options(relativePath, r.getMiddlewares(handler)...)
r.clearMiddlewares()
}

func (r *Group) Resource(relativePath string, controller httpcontract.ResourceController) {
r.instance.Get(r.getPath(relativePath), r.getMiddlewares(controller.Index)...)
r.instance.Post(r.getPath(relativePath), r.getMiddlewares(controller.Store)...)
r.instance.Get(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Show)...)
r.instance.Put(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Update)...)
r.instance.Patch(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Update)...)
r.instance.Delete(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Destroy)...)
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Get(relativePath, r.getMiddlewares(controller.Index)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Post(relativePath, r.getMiddlewares(controller.Store)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Get(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Show)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Put(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Update)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Patch(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Update)...)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Delete(r.getPath(relativePath+"/{id}"), r.getMiddlewares(controller.Destroy)...)
r.clearMiddlewares()
}

func (r *Group) Static(relativePath, root string) {
r.instance.Group(r.getPath(""), r.getMiddlewares(nil)...).Static(pathToFiberPath(relativePath), root)
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Use(r.getMiddlewaresWithPath(relativePath, nil)...).Static(relativePath, root)
r.clearMiddlewares()
}

func (r *Group) StaticFile(relativePath, filepath string) {
r.Static(relativePath, filepath)
func (r *Group) StaticFile(relativePath, filePath string) {
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Use(r.getMiddlewaresWithPath(relativePath, nil)...).Use(relativePath, func(c *fiber.Ctx) error {
dir, file := filepath.Split(filePath)
escapedFile := url.PathEscape(file)
escapedPath := filepath.Join(dir, escapedFile)

return c.SendFile(escapedPath, true)
})
r.clearMiddlewares()
}

func (r *Group) StaticFS(relativePath string, fs http.FileSystem) {
r.instance.Group(r.getPath(""), r.getMiddlewares(nil)...).Use(pathToFiberPath(relativePath), filesystem.New(filesystem.Config{
relativePath = r.getPath(relativePath)
r.instance.Use(r.getGlobalMiddlewaresWithPath(relativePath)...).Use(r.getMiddlewaresWithPath(relativePath, nil)...).Use(relativePath, filesystem.New(filesystem.Config{
Root: fs,
}))
r.clearMiddlewares()
Expand All @@ -133,6 +155,33 @@ func (r *Group) getPath(relativePath string) string {
return path
}

func (r *Group) getMiddlewaresWithPath(relativePath string, handler httpcontract.HandlerFunc) []any {
var handlers []any
handlers = append(handlers, relativePath)
middlewares := r.getMiddlewares(handler)

// Fiber will panic if no middleware is provided, So we add a dummy middleware
if len(middlewares) == 0 {
middlewares = append(middlewares, func(c *fiber.Ctx) error {
return c.Next()
})
}

for _, item := range middlewares {
handlers = append(handlers, item)
}

return handlers
}

func (r *Group) getGlobalMiddlewaresWithPath(relativePath string) []any {
var handlers []any
handlers = append(handlers, relativePath)
handlers = append(handlers, r.globalMiddlewares...)

return handlers
}

func (r *Group) clearMiddlewares() {
r.middlewares = []httpcontract.Middleware{}
}
117 changes: 70 additions & 47 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package fiber
import (
"io"
"net/http"
"os"
"path/filepath"
"testing"

contractshttp "github.com/goravel/framework/contracts/http"
Expand Down Expand Up @@ -261,13 +263,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(3)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(3)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(3)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(3)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(3)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(3)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(3)

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand All @@ -286,13 +288,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(2)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(2)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(2)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(2)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(2)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(2)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(2)

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand All @@ -311,13 +313,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(4)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(4)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(4)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(4)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(4)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(4)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(4)

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand All @@ -336,13 +338,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(5)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(5)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(5)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(5)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(5)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(5)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(5)

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand All @@ -361,13 +363,13 @@ func TestGroup(t *testing.T) {
setup: func(req *http.Request) {
mockConfig.On("GetBool", "app.debug", false).Return(true).Twice()
mockConfig.On("GetString", "app.timezone", "UTC").Return("UTC").Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Once()
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Once()
mockConfig.On("GetInt", "cors.max_age").Return(0).Once()
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Once()
mockConfig.On("Get", "cors.paths").Return([]string{"*"}).Times(6)
mockConfig.On("Get", "cors.allowed_methods").Return([]string{"*"}).Times(6)
mockConfig.On("Get", "cors.allowed_origins").Return([]string{"*"}).Times(6)
mockConfig.On("Get", "cors.allowed_headers").Return([]string{"*"}).Times(6)
mockConfig.On("Get", "cors.exposed_headers").Return([]string{"*"}).Times(6)
mockConfig.On("GetInt", "cors.max_age").Return(0).Times(6)
mockConfig.On("GetBool", "cors.supports_credentials").Return(false).Times(6)

resource := resourceController{}
fiber.GlobalMiddleware(func(ctx contractshttp.Context) {
Expand All @@ -384,29 +386,50 @@ func TestGroup(t *testing.T) {
{
name: "Static",
setup: func(req *http.Request) {
fiber.Static("static", "./")
tempDir, err := os.MkdirTemp("", "test")
assert.NoError(t, err)

err = os.WriteFile(filepath.Join(tempDir, "test.json"), []byte("{\"id\":1}"), 0755)
assert.NoError(t, err)

fiber.Static("static", tempDir)
},
method: "GET",
url: "/static/README.md",
expectCode: http.StatusOK,
method: "GET",
url: "/static/test.json",
expectCode: http.StatusOK,
expectBodyJson: "{\"id\":1}",
},
{
name: "StaticFile",
setup: func(req *http.Request) {
fiber.StaticFile("static-file", "./README.md")
file, err := os.CreateTemp("", "test")
assert.NoError(t, err)

err = os.WriteFile(file.Name(), []byte("{\"id\":1}"), 0755)
assert.NoError(t, err)

fiber.StaticFile("static-file", file.Name())
},
method: "GET",
url: "/static-file",
expectCode: http.StatusOK,
method: "GET",
url: "/static-file",
expectCode: http.StatusOK,
expectBodyJson: "{\"id\":1}",
},
{
name: "StaticFS",
setup: func(req *http.Request) {
fiber.StaticFS("static-fs", http.Dir("./"))
tempDir, err := os.MkdirTemp("", "test")
assert.NoError(t, err)

err = os.WriteFile(filepath.Join(tempDir, "test.json"), []byte("{\"id\":1}"), 0755)
assert.NoError(t, err)

fiber.StaticFS("static-fs", http.Dir(tempDir))
},
method: "GET",
url: "/static-fs/README.md",
expectCode: http.StatusOK,
method: "GET",
url: "/static-fs/test.json",
expectCode: http.StatusOK,
expectBodyJson: "{\"id\":1}",
},
{
name: "Abort Middleware",
Expand Down
9 changes: 8 additions & 1 deletion response.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package fiber

import (
"net/url"
"path/filepath"

"github.com/gofiber/fiber/v2"
)

Expand Down Expand Up @@ -32,7 +35,11 @@ type FileResponse struct {
}

func (r *FileResponse) Render() error {
return r.instance.SendFile(r.filepath)
dir, file := filepath.Split(r.filepath)
escapedFile := url.PathEscape(file)
escapedPath := filepath.Join(dir, escapedFile)

return r.instance.SendFile(escapedPath, true)
}

type JsonResponse struct {
Expand Down
Loading

0 comments on commit c3c3624

Please sign in to comment.