diff --git a/ahttp/ahttp.go b/ahttp/ahttp.go index f58e1b5..7a1e9f8 100644 --- a/ahttp/ahttp.go +++ b/ahttp/ahttp.go @@ -66,8 +66,16 @@ func (s *Server) Use(middleware ...MiddlewareFunc) { s.middleware = append(s.middleware, middleware...) } -func (s *Server) Add(path string, handler HandlerFunc, middleware ...MiddlewareFunc) { - s.router.add(path, func(c *Context) error { +func (s *Server) GET(path string, handler HandlerFunc, middleware ...MiddlewareFunc) { + s.add(http.MethodGet, path, handler, middleware...) +} + +func (s *Server) POST(path string, handler HandlerFunc, middleware ...MiddlewareFunc) { + s.add(http.MethodPost, path, handler, middleware...) +} + +func (s *Server) add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) { + s.router.add(method, path, func(c *Context) error { h := applyMiddleware(handler, middleware...) return h(c) }) @@ -93,7 +101,14 @@ func (s *Server) handleConnection(_ context.Context, connection anet.Connection) func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { c := s.pool.Get().(*Context) c.Reset(r, w) - h := s.router.find(getPath(r)) + h, params := s.router.find(r.Method, getPath(r)) + if h != nil { + for k, v := range params { + c.Set(k, v) + } + } else { + h = NotFoundHandler + } c.SetHandler(h) h = applyMiddleware(h, s.middleware...) _ = h(c) diff --git a/ahttp/router.go b/ahttp/router.go index 3596605..4d6a9b2 100644 --- a/ahttp/router.go +++ b/ahttp/router.go @@ -1,62 +1,104 @@ package ahttp -func newRouter() *router { - return &router{ - root: &node{children: make(map[byte]*node)}, - routes: make(map[string]HandlerFunc), +import ( + "strings" +) + +type node struct { + part string + isParam bool + isWild bool + isEnd bool + children map[string]*node + handler HandlerFunc +} + +func newNode(part string, isParam, isWild bool) *node { + return &node{ + part: part, + isParam: isParam, + isWild: isWild, + children: make(map[string]*node), } } -type router struct { - root *node - routes map[string]HandlerFunc +func (n *node) insert(parts []string, handler HandlerFunc) { + node := n + for i := 0; i < len(parts); i++ { + part := parts[i] + isParam := strings.HasPrefix(part, ":") + isWild := strings.HasPrefix(part, "*") + child, ok := node.children[part] + if !ok { + child = newNode(part, isParam, isWild) + node.children[part] = child + } + node = child + } + node.isEnd = true + node.handler = handler } -type node struct { - children map[byte]*node - isEnd bool +func (n *node) search(parts []string) (HandlerFunc, map[string]string) { + node := n + params := make(map[string]string) + for _, part := range parts { + child, ok := node.children[part] + if !ok { + for _, childNode := range node.children { + if childNode.isParam { + params[childNode.part[1:]] = part + child = childNode + break + } + if childNode.isWild { + params[childNode.part[1:]] = strings.Join(parts, "/") + child = childNode + break + } + } + } + if child == nil { + return nil, nil + } + node = child + } + if node.isEnd { + return node.handler, params + } + return nil, nil } -func (r *router) add(path string, h HandlerFunc) { - path = normalizePathSlash(path) - r.routes[path] = h - r.insert(path) +func newRouter() *router { + return &router{ + routes: make(map[string]*node), + } } -func (r *router) find(path string) HandlerFunc { - path = normalizePathSlash(path) - prefix := r.longestPrefixMatch(path) - return r.routes[prefix] +type router struct { + routes map[string]*node } -func (r *router) insert(prefix string) { - n := r.root - for i := 0; i < len(prefix); i++ { - char := prefix[i] - if _, found := n.children[char]; !found { - n.children[char] = &node{children: make(map[byte]*node)} - } - n = n.children[char] +func (r *router) add(method, path string, h HandlerFunc) { + path = normalizePathSlash(path) + node, ok := r.routes[method] + if !ok { + node = newNode("/", false, false) + r.routes[method] = node } - n.isEnd = true + parts := strings.Split(path, "/") + node.insert(parts, h) } -func (r *router) longestPrefixMatch(query string) string { - n := r.root - longestPrefix := "" - currentPrefix := "" - for i := 0; i < len(query); i++ { - char := query[i] - if _, found := n.children[char]; !found { - break - } - n = n.children[char] - currentPrefix += string(char) - if n.isEnd { - longestPrefix = currentPrefix - } +func (r *router) find(method, path string) (HandlerFunc, map[string]string) { + path = normalizePathSlash(path) + node, ok := r.routes[method] + if !ok { + return nil, nil } - return longestPrefix + parts := strings.Split(path, "/") + h, params := node.search(parts) + return h, params } func normalizePathSlash(path string) string { diff --git a/ahttp/router_test.go b/ahttp/router_test.go index 11b5377..9836a78 100644 --- a/ahttp/router_test.go +++ b/ahttp/router_test.go @@ -1,6 +1,7 @@ package ahttp import ( + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -8,9 +9,42 @@ import ( func TestRouter(t *testing.T) { r := newRouter() - r.add("/test", func(c *Context) error { + r.add(http.MethodGet, "/test/get", func(c *Context) error { return nil }) - assert.NotNil(t, r.find("/test")) - assert.Nil(t, r.find("/notfound")) + r.add(http.MethodPost, "/test/post", func(c *Context) error { + return nil + }) + h, _ := r.find(http.MethodGet, "/test/get") + assert.NotNil(t, h) + h, _ = r.find(http.MethodPost, "/test/post") + assert.NotNil(t, h) + h, _ = r.find(http.MethodGet, "/notfound") + assert.Nil(t, h) +} + +func TestRouterWithParam(t *testing.T) { + r := newRouter() + r.add(http.MethodGet, "/test/name/:name", func(c *Context) error { + return nil + }) + r.add(http.MethodGet, "/test/name/:name/age/:age", func(c *Context) error { + return nil + }) + h, params := r.find(http.MethodGet, "/test/name/abc") + assert.NotNil(t, h) + assert.Equal(t, "abc", params["name"]) + h, params = r.find(http.MethodGet, "/test/name/abc/age/10") + assert.NotNil(t, h) + assert.Equal(t, "abc", params["name"]) + assert.Equal(t, "10", params["age"]) +} + +func TestRouterWithWild(t *testing.T) { + r := newRouter() + r.add(http.MethodGet, "/test/*", func(c *Context) error { + return nil + }) + h, _ := r.find(http.MethodGet, "/test/wild") + assert.NotNil(t, h) } diff --git a/e2e_test/http_server/http_server.go b/e2e_test/http_server/http_server.go index 9b839c1..d47c462 100644 --- a/e2e_test/http_server/http_server.go +++ b/e2e_test/http_server/http_server.go @@ -8,7 +8,7 @@ import ( func runServer(port string, stopChan chan interface{}) { server := ahttp.New() - server.Add("/test", func(c *ahttp.Context) error { + server.GET("/test", func(c *ahttp.Context) error { return c.NoContent(http.StatusOK) })