-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdotls_test.go
120 lines (107 loc) · 2.66 KB
/
dotls_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// SPDX-License-Identifier: GPL-3.0-or-later
package dnscore
import (
"bytes"
"context"
"errors"
"net"
"testing"
"github.com/miekg/dns"
"github.com/rbmk-project/common/mocks"
"github.com/stretchr/testify/assert"
)
// TestTransport_dialTLSContext exercises (*Transport).dialTLSContext in two
// scenarios: an address without a port, which must yield the exact error
// message listed in the table, and a Transport whose DialTLSContext hook is
// overridden with a mock that always returns a connection and no error.
func TestTransport_dialTLSContext(t *testing.T) {
	tests := []struct {
		name           string
		setupTransport func() *Transport
		address        string
		expectedError  error
	}{
		{
			name: "Invalid address",
			setupTransport: func() *Transport {
				return &Transport{}
			},
			address:       "invalid-address",
			expectedError: errors.New("address invalid-address: missing port in address"),
		},
		{
			name: "Override DialTLSContext",
			setupTransport: func() *Transport {
				return &Transport{
					DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
						return &mocks.Conn{}, nil
					},
				}
			},
			address:       "example.com:853",
			expectedError: nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			transport := tt.setupTransport()
			ctx := context.Background()
			_, err := transport.dialTLSContext(ctx, "tcp", tt.address)
			if tt.expectedError != nil {
				// assert.Error does not stop the test, so following it with
				// err.Error() would panic when err is nil. EqualError checks
				// non-nil and message equality in one nil-safe assertion.
				assert.EqualError(t, err, tt.expectedError.Error())
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestTransport_queryTLS exercises (*Transport).queryTLS against a DoT server
// address, replacing the network with a DialTLSContext hook.
//
// In the "Successful query" case, the mock connection accepts every write
// and serves reads from the bytes returned by newValidRawRespFrame. That
// helper is defined elsewhere; presumably it builds a framed DNS response,
// which should be confirmed against its definition. In the "Dial failure"
// case, the hook returns an error, and the test expects queryTLS to report
// an error with exactly the same message.
func TestTransport_queryTLS(t *testing.T) {
	tests := []struct {
		name           string
		setupTransport func() *Transport
		expectedError  error
	}{
		{
			name: "Successful query",
			setupTransport: func() *Transport {
				return &Transport{
					DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
						return &mocks.Conn{
							MockWrite: func(b []byte) (int, error) {
								return len(b), nil
							},
							MockRead: (bytes.NewReader(newValidRawRespFrame())).Read,
							MockClose: func() error {
								return nil
							},
						}, nil
					},
				}
			},
			expectedError: nil,
		},
		{
			name: "Dial failure",
			setupTransport: func() *Transport {
				return &Transport{
					DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
						return nil, errors.New("dial failed")
					},
				}
			},
			expectedError: errors.New("dial failed"),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			transport := tt.setupTransport()
			addr := NewServerAddr(ProtocolDoT, "8.8.8.8:853")
			query := new(dns.Msg)
			query.SetQuestion("example.com.", dns.TypeA)
			_, err := transport.queryTLS(context.Background(), addr, query)
			if tt.expectedError != nil {
				// Nil-safe: reports a failure instead of panicking on
				// err.Error() when queryTLS unexpectedly succeeds.
				assert.EqualError(t, err, tt.expectedError.Error())
				return
			}
			assert.NoError(t, err)
		})
	}
}