-
Notifications
You must be signed in to change notification settings - Fork 0
/
split.go
216 lines (195 loc) · 5.72 KB
/
split.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
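// Package split is a CoreDNS plugin that filters A records in DNS responses
// according to the network of the querying client.
//
// Records whose address lies in a configured record network are returned only
// to clients inside one of that network's allowed networks. When filtering
// leaves a response without answers, a per-rule fallback DNS server can be
// queried instead.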
package split
import (
"context"
"net"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// log is the package-wide logger. Every message it emits is tagged with the
// "split" plugin name, so code in this package can call log.Info, log.Debugf,
// and friends directly.
var (
	log = clog.NewWithPlugin("split")
)
// contextKey is the type of the context keys defined by this package. Using an
// unexported named type instead of a plain string guarantees that no other
// package can read or overwrite these values by accident.
type contextKey string

// noFallback is the context key that marks a request as an internal lookup
// issued by this plugin. WriteMsg does not contact a fallback server for such
// requests.
const noFallback contextKey = "split-no-fallback"

// isNoFallback reports whether ctx carries noFallback set to true. A nil
// context, a context without the key, or a non-bool value all report false.
func isNoFallback(ctx context.Context) bool {
	if ctx == nil {
		return false
	}
	v, _ := ctx.Value(noFallback).(bool)
	return v
}
// Split is a plugin.Handler that filters the A records in responses produced
// by the rest of the plugin chain according to Rules.
type Split struct {
// Next is the next plugin in the chain; ServeDNS hands every query to it.
Next plugin.Handler
// Rules are consulted in order; the first rule whose Zones match the query
// name is the one applied to the response.
Rules []Rule
}
// Rule describes how answers for a set of zones are filtered.
type Rule struct {
// Zones the rule applies to, matched against the query name via plugin.Zones.
Zones []string
// Networks lists the address ranges whose A records are subject to filtering.
Networks []Network
// Fallback, if non-nil, is the address of a DNS server (queried on port 53)
// whose answers may be used when filtering leaves the response without answers.
Fallback net.IP
}
// Network restricts which clients may see A records from a given address range.
type Network struct {
// RecordNetwork selects A records: records whose address falls inside this
// network are returned only to clients whose source address is in Allowed.
RecordNetwork *net.IPNet
// Allowed lists the client source networks permitted to receive those records.
Allowed []*net.IPNet
}
// ServeDNS implements the plugin.Handler interface.
//
// It wraps w in a ResponsePrinter, so that the response produced further down
// the plugin chain is filtered by ResponsePrinter.WriteMsg. It then counts the
// request and passes the query to the next plugin.
func (s Split) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	// No response exists yet at this point: we are handling an incoming query.
	log.Debug("Received query")

	// Wrap the writer so the eventual response passes through the filter.
	pw := s.NewResponsePrinter(ctx, w, r)

	// Export metric with the server label set to the current server handling the request.
	requestCount.WithLabelValues(metrics.WithServer(ctx)).Inc()

	// Call next plugin (if any).
	return plugin.NextOrFailure(s.Name(), s.Next, ctx, pw, r)
}
// Name implements the plugin.Handler interface. It returns the constant plugin
// name, which ServeDNS also passes along when invoking the next plugin.
func (s Split) Name() string {
	const name = "split"
	return name
}
// ResponsePrinter wraps a dns.ResponseWriter. Its WriteMsg filters the answer
// section of the response before passing it on to the wrapped writer.
type ResponsePrinter struct {
// Embedded writer that finally receives the (possibly filtered) response.
dns.ResponseWriter
// ctx is the request context. It is kept so WriteMsg can issue internal
// lookups and check whether fallback is disabled for this request.
// NOTE(review): storing a Context in a struct goes against the context
// package guidance; it is needed here because WriteMsg has no ctx parameter.
ctx context.Context
// state wraps the original request (writer and message).
state request.Request
// r is the original request message.
r *dns.Msg
// src is the client source IP parsed from the request; nil if parsing failed.
src net.IP
// rules are the Split rules applied to the response.
rules []Rule
}
// NewResponsePrinter wraps w in a ResponsePrinter that applies s.Rules to the
// response for request r. The client source address is taken from w via
// request.Request.IP and parsed with net.ParseIP (nil if it does not parse).
func (s Split) NewResponsePrinter(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) *ResponsePrinter {
	st := request.Request{W: w, Req: r}
	return &ResponsePrinter{
		ResponseWriter: w,
		ctx:            ctx,
		state:          st,
		r:              r,
		src:            net.ParseIP(st.IP()),
		rules:          s.Rules,
	}
}
// WriteMsg filters the answer section of res and writes the result to the
// wrapped ResponseWriter.
//
// The rule applied is the first one in r.rules whose Zones match the query
// name. Answer records are handled by type:
//   - A: a record whose address lies in one of the rule's Networks is kept
//     only if the client source address is in that network's Allowed list;
//     records outside every Network are kept.
//   - CNAME, SRV, PTR: kept only if an internal A lookup of the target
//     (Target, or Ptr for PTR) returns at least one answer.
//   - any other type: res is written unmodified and WriteMsg returns at once.
//
// When at least one A record was kept because of an Allowed match, only those
// records are returned. Fallback is attempted only when all of the following
// hold:
//   - the filtered answer is empty;
//   - the response contained an A record;
//   - the query name matched a rule;
//   - fallback is not disabled on the context;
//   - the rule has a Fallback server.
//
// The fallback sends the original request to that server on port 53 and uses
// its answers. If the fallback exchange fails, the error is logged and the
// filtered response is written anyway, so the client still gets a reply.
func (r *ResponsePrinter) WriteMsg(res *dns.Msg) error {
	// The rule depends only on the query name, so resolve it once instead of
	// once per A record.
	matched := r.matchingRule()

	var (
		// rule stays the zero Rule unless the answer contains an A record, so
		// fallback is only considered for responses that carried A records.
		rule       Rule
		answers    []dns.RR
		netAnswers []dns.RR
	)
	for _, rr := range res.Answer {
		switch rec := rr.(type) {
		case *dns.A:
			rule = matched
			allowed, inNetwork := r.classifyA(matched, rec)
			switch {
			case !inNetwork:
				answers = append(answers, rr)
			case allowed:
				answers = append(answers, rr)
				netAnswers = append(netAnswers, rr)
			default:
				log.Infof("request source %s: %s: filtering %s", r.src.String(), rec.Hdr.Name, rec.A)
			}
		case *dns.CNAME:
			if r.targetResolves(rec.Target) {
				answers = append(answers, rr)
			}
		case *dns.SRV:
			if r.targetResolves(rec.Target) {
				answers = append(answers, rr)
			}
		case *dns.PTR:
			if r.targetResolves(rec.Ptr) {
				answers = append(answers, rr)
			}
		default:
			// Record types this filter does not handle: pass the whole
			// response through untouched.
			return r.ResponseWriter.WriteMsg(res)
		}
	}

	// Prefer the records that were explicitly allowed for the client's network.
	res.Answer = answers
	if len(netAnswers) != 0 {
		res.Answer = netAnswers
	}

	if len(res.Answer) != 0 || len(rule.Zones) == 0 {
		return r.ResponseWriter.WriteMsg(res)
	}
	if isNoFallback(r.ctx) {
		log.Debugf("no fallback requested for %s", r.state.Name())
		return r.ResponseWriter.WriteMsg(res)
	}
	if rule.Fallback == nil {
		log.Debugf("no fallback configured for zones %v", rule.Zones)
		return r.ResponseWriter.WriteMsg(res)
	}

	log.Debugf("request source %s: %s: using fallback %s", r.src.String(), r.state.Name(), rule.Fallback)
	req := r.state.Req.Copy()
	req.Id = dns.Id()
	// net.JoinHostPort brackets IPv6 literals; plain string concatenation
	// would produce an unusable address such as "fd00::1:53".
	addr := net.JoinHostPort(rule.Fallback.String(), "53")
	in, _, err := new(dns.Client).Exchange(req, addr)
	if err != nil {
		// Still write the filtered response so the client receives a reply
		// instead of nothing at all.
		log.Errorf("error querying fallback %s for %s: %s", addr, r.state.Name(), err)
		return r.ResponseWriter.WriteMsg(res)
	}
	res.Answer = append(res.Answer, in.Answer...)
	return r.ResponseWriter.WriteMsg(res)
}

// matchingRule returns the first rule whose zones match the query name, or the
// zero Rule when none do.
func (r *ResponsePrinter) matchingRule() Rule {
	qname := r.state.Name()
	for _, candidate := range r.rules {
		if plugin.Zones(candidate.Zones).Matches(qname) != "" {
			return candidate
		}
	}
	return Rule{}
}

// classifyA reports whether rec.A lies inside one of rule's record networks
// (inNetwork) and, if so, whether the client source address is in one of that
// network's Allowed ranges (allowed). The first containing network decides.
// Records outside every network report allowed=true, inNetwork=false.
func (r *ResponsePrinter) classifyA(rule Rule, rec *dns.A) (allowed, inNetwork bool) {
	for _, n := range rule.Networks {
		if !n.RecordNetwork.Contains(rec.A) {
			continue
		}
		for _, cidr := range n.Allowed {
			if cidr.Contains(r.src) {
				return true, true
			}
		}
		return false, true
	}
	return true, false
}

// targetResolves reports whether an internal A lookup of name returns at least
// one answer. Lookup errors are logged and treated as "does not resolve"; a
// nil result message is handled rather than dereferenced.
func (r *ResponsePrinter) targetResolves(name string) bool {
	m, err := r.query(name)
	if err != nil {
		log.Errorf("error querying %s: %s", name, err)
		return false
	}
	if m == nil || len(m.Answer) == 0 {
		log.Debugf("no answers for %s", name)
		return false
	}
	return true
}
// query performs an internal A lookup of name via upstream.Lookup, reusing the
// current request state. The context handed down is tagged with noFallback set
// to true. If the nested lookup presumably passes through this plugin again,
// its WriteMsg will therefore not contact a fallback server.
// On error the returned message is always nil.
func (r *ResponsePrinter) query(name string) (*dns.Msg, error) {
	log.Debugf("internally querying %s", name)
	internalCtx := context.WithValue(r.ctx, noFallback, true)
	up := upstream.New()
	m, err := up.Lookup(internalCtx, r.state, name, dns.TypeA)
	if err != nil {
		// Drop any partial message so callers only ever see nil on error.
		return nil, err
	}
	return m, nil
}