-
Notifications
You must be signed in to change notification settings - Fork 0
/
router.go
152 lines (127 loc) · 3.69 KB
/
router.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package httpRouting
import (
"context"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
)
type (
	// router dispatches HTTP requests to handlers by request method and a
	// regular expression matched against the whole URL path. It can also
	// attach a fixed set of CORS headers to responses. Create one with
	// NewRouterBuilder.
	router struct {
		// routes holds, per upper-cased HTTP method, the registered routes
		// in registration order.
		routes map[string][]route
		// corsHeaders maps CORS response header names to the values that
		// ServeWithCORS sets on every response.
		corsHeaders map[string]string
	}

	// ContextKey is the type of the request-context key under which the
	// router stores the parameters captured from the matched URL path.
	ContextKey string
)
// NewRouterBuilder returns a new router with no routes and no CORS headers.
// Both internal maps are allocated, so routes can be registered and the
// chainable Set* methods can be called immediately.
func NewRouterBuilder() *router {
	r := new(router)
	r.routes = map[string][]route{}
	r.corsHeaders = map[string]string{}
	return r
}
// SetAllowOrigin configures the Access-Control-Allow-Origin header.
// Leading and trailing whitespace is trimmed from o. If nothing remains,
// the call changes nothing and any previously set origin is kept.
// It returns r so calls can be chained.
func (r *router) SetAllowOrigin(o string) *router {
	trimmed := strings.TrimSpace(o)
	if len(trimmed) == 0 {
		return r
	}
	r.corsHeaders["Access-Control-Allow-Origin"] = trimmed
	return r
}
// SetAllowMethods configures the Access-Control-Allow-Methods header.
// The entries of m are joined with ", " exactly as given; no case
// normalization is applied. An empty (or nil) m leaves any earlier value
// untouched. It returns r so calls can be chained.
func (r *router) SetAllowMethods(m []string) *router {
	if len(m) == 0 {
		return r
	}
	r.corsHeaders["Access-Control-Allow-Methods"] = strings.Join(m, ", ")
	return r
}
// SetAllowHeaders configures the Access-Control-Allow-Headers header.
// The entries of h are joined with ", " exactly as given. An empty (or nil)
// h leaves any earlier value untouched. It returns r so calls can be
// chained.
func (r *router) SetAllowHeaders(h []string) *router {
	if len(h) == 0 {
		return r
	}
	r.corsHeaders["Access-Control-Allow-Headers"] = strings.Join(h, ", ")
	return r
}
// SetCredentials sets Access-Control-Allow-Credentials to "true" when c is
// true. Passing false is a no-op: it does not remove a value set by an
// earlier call. It returns r so calls can be chained.
func (r *router) SetCredentials(c bool) *router {
	if c {
		r.corsHeaders["Access-Control-Allow-Credentials"] = "true"
	}
	return r
}

// SetCredantials is a misspelled alias of SetCredentials. It is kept so
// existing callers keep compiling.
//
// Deprecated: use SetCredentials instead.
func (r *router) SetCredantials(c bool) *router {
	return r.SetCredentials(c)
}
// NewRoute registers handler for requests whose HTTP method equals method
// and whose URL path matches regexpString in its entirety. The method is
// stored upper-cased, and incoming request methods are upper-cased as well.
//
// Routes are tried in registration order, and the first match wins.
//
// The pattern is anchored at both ends and wrapped in a non-capturing
// group. An alternation such as "/a|/b" therefore has to match the whole
// path, not just a prefix or suffix of it.
//
// Capture groups become request parameters keyed by group name. Name the
// groups with (?P<name>...) and read them in the handler through
// GetRequestParamString or GetRequestParamInt.
//
// middlewareBefore wraps handler, and the first middleware listed becomes
// the outermost wrapper. The slice is copied, so later changes to the
// caller's slice do not affect the registered route.
//
// NewRoute panics if regexpString is not a valid regular expression.
//
// regexp example -> https://regex101.com/r/84S9iL/1
func (r *router) NewRoute(method, regexpString string, handler http.HandlerFunc, middlewareBefore ...Middleware) {
	// Without the (?:...) group, "^a|b$" would mean "starts with a OR ends
	// with b". The non-capturing group does not shift capture-group indices.
	regex := regexp.MustCompile("^(?:" + regexpString + ")$")
	method = strings.ToUpper(method)
	// Detach from the caller's backing array; a nil input stays nil.
	mws := append([]Middleware(nil), middlewareBefore...)
	r.routes[method] = append(r.routes[method], route{
		regex,
		handler,
		mws,
		make([]Middleware, 0),
	})
}
// Serve dispatches req to the first registered route whose method matches
// and whose pattern matches the entire URL path. When no route matches, it
// replies 404 Not Found. No CORS headers are added.
//
// The named capture groups of the matching pattern are stored in the
// request context. The handler can read them with GetRequestParamString or
// GetRequestParamInt.
func (r *router) Serve(w http.ResponseWriter, req *http.Request) {
	r.serve(w, req)
}

// ServeHTTP implements http.Handler by delegating to Serve. This lets a
// *router be passed directly wherever an http.Handler is expected (for
// example http.Handle or an http.Server's Handler field).
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	r.Serve(w, req)
}
// ServeWithCORS returns a handler that first sets every configured CORS
// header on the response. It then routes the request exactly like Serve.
// OPTIONS requests are answered right after the headers are set, without
// writing a status or body, and are never passed to a registered route.
func (r *router) ServeWithCORS() http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		hdr := w.Header()
		for name, val := range r.corsHeaders {
			hdr.Set(name, val)
		}
		if req.Method != http.MethodOptions {
			r.serve(w, req)
		}
	})
}
// serve is the dispatch routine shared by Serve and ServeWithCORS.
//
// It tries the routes registered for the upper-cased request method in
// registration order, and the first route whose regex matches req.URL.Path
// wins. That route's capture groups are collected into a map keyed by group
// name and stored in the request context under ContextKey("requestParams").
// The route handler is then called, wrapped by its middlewareBefore chain.
// If no route matches, it replies 404 Not Found.
func (r *router) serve(w http.ResponseWriter, req *http.Request) {
	path := req.URL.Path
	for _, rt := range r.routes[strings.ToUpper(req.Method)] {
		submatches := rt.regex.FindStringSubmatch(path)
		if len(submatches) == 0 {
			continue
		}

		// SubexpNames and FindStringSubmatch share indexing: entry 0 is the
		// whole match and entry k is capture group k.
		names := rt.regex.SubexpNames()
		params := make(map[string]string)
		for k, value := range submatches[1:] {
			params[names[k+1]] = value
		}
		ctx := context.WithValue(req.Context(), ContextKey("requestParams"), params)

		// Wrap from the innermost layer outward, so middlewareBefore[0]
		// ends up as the outermost wrapper.
		h := rt.handler
		for k := len(rt.middlewareBefore); k > 0; k-- {
			h = rt.middlewareBefore[k-1](h)
		}
		h.ServeHTTP(w, req.WithContext(ctx))
		return
	}
	w.WriteHeader(http.StatusNotFound)
}
// GetRequestParamString returns the value captured by the group called name
// in the route that matched r.
//
// It returns an error in two cases:
//   - r's context carries no parameter map, for example because r was not
//     dispatched by this router.
//   - The matched pattern has no group called name.
//
// A group that exists but did not participate in the match yields "" and a
// nil error.
func GetRequestParamString(r *http.Request, name string) (string, error) {
	fields, ok := r.Context().Value(ContextKey("requestParams")).(map[string]string)
	if !ok {
		return "", fmt.Errorf("internal error: no fields in context")
	}
	field, exist := fields[name]
	if !exist {
		return "", fmt.Errorf("no such variable in request: %v", name)
	}
	return field, nil
}
// GetRequestParamInt returns the value captured by the group called name,
// parsed as a base-10 int with strconv.Atoi.
//
// Lookup errors from GetRequestParamString are returned unchanged. A parse
// failure is returned wrapped with the variable name. The underlying
// *strconv.NumError stays reachable through errors.As.
func GetRequestParamInt(r *http.Request, name string) (int, error) {
	field, err := GetRequestParamString(r, name)
	if err != nil {
		return 0, err
	}
	n, err := strconv.Atoi(field)
	if err != nil {
		return 0, fmt.Errorf("request variable %q is not an integer: %w", name, err)
	}
	return n, nil
}