distributed.go (forked from mholt/caddy-ratelimit)
// Copyright 2021 Matthew Holt
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package caddyrl

import (
	"bytes"
	"context"
	"encoding/gob"
	"net/http"
	"path"
	"strings"
	"sync"
	"time"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/certmagic"
	"go.uber.org/zap"
)
// DistributedRateLimiting enables and customizes distributed rate limiting.
// It works by writing out the state of all internal rate limiters to storage,
// and reading in the state of all other rate limiters in the cluster, every
// so often.
//
// Distributed rate limiting is not as exact as the standard internal rate
// limiting, but it is eventually consistent. Lower (more frequent) sync
// intervals result in higher consistency and precision, but also more I/O
// and CPU overhead.
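//
// As a rough illustration, the config block for this struct might look like
// the following ("write_interval" and "read_interval" come from the JSON
// tags below; the enclosing "distributed" key and its exact placement in the
// handler's config are assumptions, since the handler type is defined
// elsewhere):
//
//	"distributed": {
//	    "write_interval": "5s",
//	    "read_interval": "5s"
//	}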
type DistributedRateLimiting struct {
	// How often to sync internal state to storage. Default: 5s
	WriteInterval caddy.Duration `json:"write_interval,omitempty"`

	// How often to sync other instances' states from storage.
	// Default: 5s
	ReadInterval caddy.Duration `json:"read_interval,omitempty"`

	instanceID string

	otherStates   []rlState
	otherStatesMu sync.RWMutex
}
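
// syncDistributed runs until ctx is canceled, periodically writing this
// instance's rate limiter state to storage and reading in the states
// written by other instances.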
func (h Handler) syncDistributed(ctx context.Context) {
	readTicker := time.NewTicker(time.Duration(h.Distributed.ReadInterval))
	writeTicker := time.NewTicker(time.Duration(h.Distributed.WriteInterval))
	defer readTicker.Stop()
	defer writeTicker.Stop()

	for {
		select {
		case <-readTicker.C:
			// get all the latest stored rate limiter states
			err := h.syncDistributedRead(ctx)
			if err != nil {
				h.logger.Error("syncing distributed limiter states", zap.Error(err))
			}

		case <-writeTicker.C:
			// store all current rate limiter states
			err := h.syncDistributedWrite(ctx)
			if err != nil {
				h.logger.Error("distributing internal state", zap.Error(err))
			}

		case <-ctx.Done():
			return
		}
	}
}

// syncDistributedWrite stores all rate limiter states.
func (h Handler) syncDistributedWrite(ctx context.Context) error {
	state := rlState{
		Timestamp: now(),
		Zones:     make(map[string]map[string]rlStateValue),
	}

	// iterate all rate limit zones
	rateLimits.Range(func(zoneName, value interface{}) bool {
		zoneNameStr := zoneName.(string)
		zoneLimiters := value.(*rateLimitersMap)
		state.Zones[zoneNameStr] = zoneLimiters.rlStateForZone(state.Timestamp)
		return true
	})

	return writeRateLimitState(ctx, state, h.Distributed.instanceID, h.storage)
}
func writeRateLimitState(ctx context.Context, state rlState, instanceID string, storage certmagic.Storage) error {
	buf := gobBufPool.Get().(*bytes.Buffer)
	buf.Reset()
	defer gobBufPool.Put(buf)

	err := gob.NewEncoder(buf).Encode(state)
	if err != nil {
		return err
	}

	err = storage.Store(ctx, path.Join(storagePrefix, instanceID+".rlstate"), buf.Bytes())
	if err != nil {
		return err
	}

	return nil
}

// syncDistributedRead loads all rate limiter states from other instances.
func (h Handler) syncDistributedRead(ctx context.Context) error {
	instanceFiles, err := h.storage.List(ctx, storagePrefix, false)
	if err != nil {
		return err
	}
	if len(instanceFiles) == 0 {
		return nil
	}

	otherStates := make([]rlState, 0, len(instanceFiles)-1)

	for _, instanceFile := range instanceFiles {
		// skip our own file
		if strings.HasSuffix(instanceFile, h.Distributed.instanceID+".rlstate") {
			continue
		}

		encoded, err := h.storage.Load(ctx, instanceFile)
		if err != nil {
			h.logger.Error("unable to load distributed rate limiter state",
				zap.String("key", instanceFile),
				zap.Error(err))
			continue
		}

		var state rlState
		err = gob.NewDecoder(bytes.NewReader(encoded)).Decode(&state)
		if err != nil {
			h.logger.Error("corrupted rate limiter state file",
				zap.String("key", instanceFile),
				zap.Error(err))
			continue
		}

		otherStates = append(otherStates, state)
	}

	h.Distributed.otherStatesMu.Lock()
	h.Distributed.otherStates = otherStates
	h.Distributed.otherStatesMu.Unlock()

	return nil
}

// distributedRateLimiting enforces limiter (keyed by rlKey) in consideration of
// all other instances in the cluster. If the limit is exceeded, the response is
// prepared and the relevant error is returned. Otherwise, a reservation is made
// in the local limiter and no error is returned.
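//
// For example (illustrative numbers only): with a zone limit of 100 events
// per window, if two other instances last reported counts of 40 and 35
// within the window, then a local unsynced count of 25 or more brings the
// total to the limit, and the request is rejected via h.rateLimitExceeded
// with a wait duration computed from the oldest known event; otherwise the
// event is reserved in the local limiter and the request is allowed.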
func (h Handler) distributedRateLimiting(w http.ResponseWriter, r *http.Request, repl *caddy.Replacer, limiter *ringBufferRateLimiter, rlKey, zoneName string) error {
	maxAllowed := limiter.MaxEvents()
	window := limiter.Window()

	var totalCount int
	oldestEvent := now()

	h.Distributed.otherStatesMu.RLock()
	defer h.Distributed.otherStatesMu.RUnlock()

	for _, otherInstanceState := range h.Distributed.otherStates {
		// if instance hasn't reported in longer than the window, no point in counting with it
		if otherInstanceState.Timestamp.Before(now().Add(-window)) {
			continue
		}

		// if instance has this zone, add last known limiter count
		if zone, ok := otherInstanceState.Zones[zoneName]; ok {
			// TODO: could probably skew the numbers here based on timestamp and window... perhaps try to predict a better updated count
			totalCount += zone[rlKey].Count

			if zone[rlKey].OldestEvent.Before(oldestEvent) && zone[rlKey].OldestEvent.After(now().Add(-window)) {
				oldestEvent = zone[rlKey].OldestEvent
			}

			// no point in counting more if we're already over
			if totalCount >= maxAllowed {
				return h.rateLimitExceeded(w, r, repl, zoneName, oldestEvent.Add(window).Sub(now()))
			}
		}
	}

	// add our own internal count (we do this at the end instead of the beginning
	// so the critical section over this limiter's lock is smaller), and make the
	// reservation if we're within the limit
	limiter.mu.Lock()
	count, oldestLocalEvent := limiter.countUnsynced(now())
	totalCount += count

	if oldestLocalEvent.Before(oldestEvent) && oldestLocalEvent.After(now().Add(-window)) {
		oldestEvent = oldestLocalEvent
	}

	if totalCount < maxAllowed {
		limiter.reserve()
		limiter.mu.Unlock()
		return nil
	}
	limiter.mu.Unlock()

	// otherwise, it appears limit has been exceeded
	return h.rateLimitExceeded(w, r, repl, zoneName, oldestEvent.Add(window).Sub(now()))
}

type rlStateValue struct {
	// Count of events within window
	Count int

	// Time at which the oldest event in the limiter occurred
	OldestEvent time.Time
}

type rlState struct {
	// When these values were recorded.
	Timestamp time.Time

	// Map of zone name to map of all rate limiters in that zone by key to the
	// number of events within window and time at which the oldest event
	// occurred.
	Zones map[string]map[string]rlStateValue
}
var gobBufPool = sync.Pool{
	New: func() interface{} {
		return new(bytes.Buffer)
	},
}
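
// storagePrefix is the storage key prefix under which each instance writes
// its gob-encoded state file (named "<instanceID>.rlstate").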
const storagePrefix = "rate_limit/instances"