Skip to content

Commit

Permalink
feat: add dynamic routing
Browse files Browse the repository at this point in the history
  • Loading branch information
zjregee committed Aug 26, 2024
1 parent be17780 commit 92a2c5e
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 49 deletions.
21 changes: 18 additions & 3 deletions ahttp/ahttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,16 @@ func (s *Server) Use(middleware ...MiddlewareFunc) {
s.middleware = append(s.middleware, middleware...)
}

func (s *Server) Add(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
s.router.add(path, func(c *Context) error {
func (s *Server) GET(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
s.add(http.MethodGet, path, handler, middleware...)
}

func (s *Server) POST(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
s.add(http.MethodPost, path, handler, middleware...)
}

func (s *Server) add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
s.router.add(method, path, func(c *Context) error {
h := applyMiddleware(handler, middleware...)
return h(c)
})
Expand All @@ -93,7 +101,14 @@ func (s *Server) handleConnection(_ context.Context, connection anet.Connection)
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) {
c := s.pool.Get().(*Context)
c.Reset(r, w)
h := s.router.find(getPath(r))
h, params := s.router.find(r.Method, getPath(r))
if h != nil {
for k, v := range params {
c.Set(k, v)
}
} else {
h = NotFoundHandler
}
c.SetHandler(h)
h = applyMiddleware(h, s.middleware...)
_ = h(c)
Expand Down
126 changes: 84 additions & 42 deletions ahttp/router.go
Original file line number Diff line number Diff line change
@@ -1,62 +1,104 @@
package ahttp

func newRouter() *router {
return &router{
root: &node{children: make(map[byte]*node)},
routes: make(map[string]HandlerFunc),
import (
"strings"
)

type node struct {
part string
isParam bool
isWild bool
isEnd bool
children map[string]*node
handler HandlerFunc
}

func newNode(part string, isParam, isWild bool) *node {
return &node{
part: part,
isParam: isParam,
isWild: isWild,
children: make(map[string]*node),
}
}

type router struct {
root *node
routes map[string]HandlerFunc
func (n *node) insert(parts []string, handler HandlerFunc) {
node := n
for i := 0; i < len(parts); i++ {
part := parts[i]
isParam := strings.HasPrefix(part, ":")
isWild := strings.HasPrefix(part, "*")
child, ok := node.children[part]
if !ok {
child = newNode(part, isParam, isWild)
node.children[part] = child
}
node = child
}
node.isEnd = true
node.handler = handler
}

type node struct {
children map[byte]*node
isEnd bool
func (n *node) search(parts []string) (HandlerFunc, map[string]string) {
node := n
params := make(map[string]string)
for _, part := range parts {
child, ok := node.children[part]
if !ok {
for _, childNode := range node.children {
if childNode.isParam {
params[childNode.part[1:]] = part
child = childNode
break
}
if childNode.isWild {
params[childNode.part[1:]] = strings.Join(parts, "/")
child = childNode
break
}
}
}
if child == nil {
return nil, nil
}
node = child
}
if node.isEnd {
return node.handler, params
}
return nil, nil
}

func (r *router) add(path string, h HandlerFunc) {
path = normalizePathSlash(path)
r.routes[path] = h
r.insert(path)
func newRouter() *router {
return &router{
routes: make(map[string]*node),
}
}

func (r *router) find(path string) HandlerFunc {
path = normalizePathSlash(path)
prefix := r.longestPrefixMatch(path)
return r.routes[prefix]
type router struct {
routes map[string]*node
}

func (r *router) insert(prefix string) {
n := r.root
for i := 0; i < len(prefix); i++ {
char := prefix[i]
if _, found := n.children[char]; !found {
n.children[char] = &node{children: make(map[byte]*node)}
}
n = n.children[char]
func (r *router) add(method, path string, h HandlerFunc) {
path = normalizePathSlash(path)
node, ok := r.routes[method]
if !ok {
node = newNode("/", false, false)
r.routes[method] = node
}
n.isEnd = true
parts := strings.Split(path, "/")
node.insert(parts, h)
}

func (r *router) longestPrefixMatch(query string) string {
n := r.root
longestPrefix := ""
currentPrefix := ""
for i := 0; i < len(query); i++ {
char := query[i]
if _, found := n.children[char]; !found {
break
}
n = n.children[char]
currentPrefix += string(char)
if n.isEnd {
longestPrefix = currentPrefix
}
func (r *router) find(method, path string) (HandlerFunc, map[string]string) {
path = normalizePathSlash(path)
node, ok := r.routes[method]
if !ok {
return nil, nil
}
return longestPrefix
parts := strings.Split(path, "/")
h, params := node.search(parts)
return h, params
}

func normalizePathSlash(path string) string {
Expand Down
40 changes: 37 additions & 3 deletions ahttp/router_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,50 @@
package ahttp

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestRouter(t *testing.T) {
r := newRouter()
r.add("/test", func(c *Context) error {
r.add(http.MethodGet, "/test/get", func(c *Context) error {
return nil
})
assert.NotNil(t, r.find("/test"))
assert.Nil(t, r.find("/notfound"))
r.add(http.MethodPost, "/test/post", func(c *Context) error {
return nil
})
h, _ := r.find(http.MethodGet, "/test/get")
assert.NotNil(t, h)
h, _ = r.find(http.MethodPost, "/test/post")
assert.NotNil(t, h)
h, _ = r.find(http.MethodGet, "/notfound")
assert.Nil(t, h)
}

func TestRouterWithParam(t *testing.T) {
r := newRouter()
r.add(http.MethodGet, "/test/name/:name", func(c *Context) error {
return nil
})
r.add(http.MethodGet, "/test/name/:name/age/:age", func(c *Context) error {
return nil
})
h, params := r.find(http.MethodGet, "/test/name/abc")
assert.NotNil(t, h)
assert.Equal(t, "abc", params["name"])
h, params = r.find(http.MethodGet, "/test/name/abc/age/10")
assert.NotNil(t, h)
assert.Equal(t, "abc", params["name"])
assert.Equal(t, "10", params["age"])
}

func TestRouterWithWild(t *testing.T) {
r := newRouter()
r.add(http.MethodGet, "/test/*", func(c *Context) error {
return nil
})
h, _ := r.find(http.MethodGet, "/test/wild")
assert.NotNil(t, h)
}
2 changes: 1 addition & 1 deletion e2e_test/http_server/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

func runServer(port string, stopChan chan interface{}) {
server := ahttp.New()
server.Add("/test", func(c *ahttp.Context) error {
server.GET("/test", func(c *ahttp.Context) error {
return c.NoContent(http.StatusOK)
})

Expand Down

0 comments on commit 92a2c5e

Please sign in to comment.