Skip to content

Commit

Permalink
chore: allow handlers to have params with different names
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Aug 30, 2022
1 parent 86ff925 commit 1a74492
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 51 deletions.
15 changes: 8 additions & 7 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,16 @@ func (g *Group) Handle(meth string, path string, handler HandlerFunc) {
panic(fmt.Errorf("routes %q and %q can't both handle %s", node.route, path, meth))
}
}
node.setHandler(meth, g.wrap(handler))

if !paramsEqual(node.params, params) {
panic(fmt.Errorf("routes %q and %q have different param names for the same route",
node.route, path))
}
node.setHandler(meth, &routeHandler{
fn: g.wrap(handler),
params: params,
})

if node.handlerMap.notAllowed == nil {
node.handlerMap.notAllowed = g.wrap(g.router.methodNotAllowedHandler)
node.handlerMap.notAllowed = &routeHandler{
fn: g.wrap(g.router.methodNotAllowedHandler),
params: params,
}
}
}

Expand Down
45 changes: 18 additions & 27 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ type node struct {
route string
part string

params map[string]int // param name => param position
handlerMap *handlerMap

parent *node
Expand All @@ -36,7 +35,6 @@ func (n *node) addRoute(route string) (*node, map[string]int) {

if currNode.route == "" {
currNode.route = route
currNode.params = params
}
n.indexNodes()

Expand Down Expand Up @@ -108,7 +106,7 @@ func (n *node) addPart(part string) *node {
return node
}

func (n *node) findRoute(meth, path string) (*node, HandlerFunc, int) {
func (n *node) findRoute(meth, path string) (*node, *routeHandler, int) {
if path == "" {
return nil, nil, 0
}
Expand All @@ -124,7 +122,7 @@ func (n *node) findRoute(meth, path string) (*node, HandlerFunc, int) {
return n._findRoute(meth, path)
}

func (n *node) _findRoute(meth, path string) (*node, HandlerFunc, int) {
func (n *node) _findRoute(meth, path string) (*node, *routeHandler, int) {
var found *node

if firstChar := path[0]; firstChar >= n.index.minChar && firstChar <= n.index.maxChar {
Expand Down Expand Up @@ -219,7 +217,7 @@ func (n *node) _indexNodes() {
}
}

func (n *node) setHandler(verb string, handler HandlerFunc) {
func (n *node) setHandler(verb string, handler *routeHandler) {
if n.handlerMap == nil {
n.handlerMap = newHandlerMap()
}
Expand Down Expand Up @@ -336,21 +334,26 @@ func paramMap(route string, params []string) map[string]int {
//------------------------------------------------------------------------------

type handlerMap struct {
get HandlerFunc
post HandlerFunc
put HandlerFunc
delete HandlerFunc
head HandlerFunc
options HandlerFunc
patch HandlerFunc
notAllowed HandlerFunc
get *routeHandler
post *routeHandler
put *routeHandler
delete *routeHandler
head *routeHandler
options *routeHandler
patch *routeHandler
notAllowed *routeHandler
}

type routeHandler struct {
fn HandlerFunc
params map[string]int // param name => param position
}

func newHandlerMap() *handlerMap {
return new(handlerMap)
}

func (h *handlerMap) Get(meth string) HandlerFunc {
func (h *handlerMap) Get(meth string) *routeHandler {
switch meth {
case http.MethodGet:
return h.get
Expand All @@ -371,7 +374,7 @@ func (h *handlerMap) Get(meth string) HandlerFunc {
}
}

func (h *handlerMap) Set(meth string, handler HandlerFunc) {
func (h *handlerMap) Set(meth string, handler *routeHandler) {
switch meth {
case http.MethodGet:
h.get = handler
Expand All @@ -391,15 +394,3 @@ func (h *handlerMap) Set(meth string, handler HandlerFunc) {
panic(fmt.Errorf("unknown HTTP method: %s", meth))
}
}

func paramsEqual(m1, m2 map[string]int) bool {
if len(m1) != len(m2) {
return false
}
for k, v1 := range m1 {
if v2, ok := m2[k]; !ok || v1 != v2 {
return false
}
}
return true
}
17 changes: 9 additions & 8 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ func (req Request) Route() string {
type Params struct {
path string
node *node
handler *routeHandler
wildcardLen uint16
}

Expand All @@ -122,7 +123,7 @@ func (ps Params) Get(name string) (string, bool) {
if ps.node == nil {
return "", false
}
if i, ok := ps.node.params[name]; ok {
if i, ok := ps.handler.params[name]; ok {
return ps.findParam(i)
}
return "", false
Expand All @@ -132,7 +133,7 @@ func (ps *Params) findParam(paramIndex int) (string, bool) {
path := ps.path
pathLen := len(path)
currNode := ps.node
currParamIndex := len(ps.node.params) - 1
currParamIndex := len(ps.handler.params) - 1

// Wildcard can be only in the final node.
if ps.node.isWC {
Expand Down Expand Up @@ -196,11 +197,11 @@ func (ps Params) Int64(name string) (int64, error) {
}

func (ps Params) Map() map[string]string {
if ps.node == nil || len(ps.node.params) == 0 {
if ps.handler == nil || len(ps.handler.params) == 0 {
return nil
}
m := make(map[string]string, len(ps.node.params))
for param, index := range ps.node.params {
m := make(map[string]string, len(ps.handler.params))
for param, index := range ps.handler.params {
if value, ok := ps.findParam(index); ok {
m[param] = value
}
Expand All @@ -214,11 +215,11 @@ type Param struct {
}

func (ps Params) Slice() []Param {
if ps.node == nil || len(ps.node.params) == 0 {
if ps.handler == nil || len(ps.handler.params) == 0 {
return nil
}
slice := make([]Param, len(ps.node.params))
for param, index := range ps.node.params {
slice := make([]Param, len(ps.handler.params))
for param, index := range ps.handler.params {
if value, ok := ps.findParam(index); ok {
slice[index] = Param{Key: param, Value: value}
}
Expand Down
3 changes: 2 additions & 1 deletion router.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ func (r *Router) lookup(w http.ResponseWriter, req *http.Request) (HandlerFunc,
handler = node.handlerMap.notAllowed
}

return handler, Params{
return handler.fn, Params{
node: node,
handler: handler,
path: path,
wildcardLen: uint16(wildcardLen),
}
Expand Down
31 changes: 23 additions & 8 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1005,20 +1005,35 @@ func TestMultipleMiddlewaresAndMethodNotAllowed(t *testing.T) {
func TestSameRouteWithDifferentParams(t *testing.T) {
router := New()

router.GET("/:foo", dummyHandler)
require.PanicsWithError(
t,
`routes "/:foo" and "/:bar" have different param names for the same route`,
func() {
router.HEAD("/:bar", dummyHandler)
},
)
router.GET("/:foo", func(w http.ResponseWriter, req Request) error {
require.Equal(t, map[string]string{"foo": "hello"}, req.Params().Map())
return nil
})
router.HEAD("/:bar", func(w http.ResponseWriter, req Request) error {
require.Equal(t, map[string]string{"bar": "hello"}, req.Params().Map())
return nil
})

{
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/hello", nil)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}

{
w := httptest.NewRecorder()
req, _ := http.NewRequest("HEAD", "/hello", nil)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
}

func TestConflictingPlainAndWilcardRoutes(t *testing.T) {
router := New()

router.GET("/", dummyHandler)
router.POST("/*path", dummyHandler)
require.PanicsWithError(
t,
`routes "/" and "/*path" can't both handle GET`,
Expand Down

0 comments on commit 1a74492

Please sign in to comment.