Skip to content

Commit

Permalink
šŸ  conn, service: support custom "network" for endpoint address resoluā€¦
Browse files Browse the repository at this point in the history
ā€¦tion
  • Loading branch information
database64128 committed Jan 28, 2024
1 parent daffdee commit 5bbf87a
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 54 deletions.
21 changes: 13 additions & 8 deletions conn/addr.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,14 @@ func (a Addr) IPPort() netip.AddrPort {

// ResolveIP resolves a domain name string into an IP address.
//
// The network must be one of "ip", "ip4" or "ip6".
// String representations of IP addresses are not supported.
//
// This function always returns the first IP address returned by the resolver,
// because the resolver takes care of sorting the IP addresses by address family
// availability and preference.
//
// String representations of IP addresses are not supported.
func ResolveIP(ctx context.Context, host string) (netip.Addr, error) {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", host)
func ResolveIP(ctx context.Context, network, host string) (netip.Addr, error) {
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, host)
if err != nil {
return netip.Addr{}, err
}
Expand All @@ -134,13 +135,15 @@ func ResolveIP(ctx context.Context, host string) (netip.Addr, error) {

// ResolveIP returns the IP address itself or the resolved IP address of the domain name.
//
// The network is only used for domain name resolution and must be one of "ip", "ip4" or "ip6".
//
// If the address is zero value, this method panics.
func (a Addr) ResolveIP(ctx context.Context) (netip.Addr, error) {
func (a Addr) ResolveIP(ctx context.Context, network string) (netip.Addr, error) {
switch a.af {
case addressFamilyNetip:
return a.ip(), nil
case addressFamilyDomain:
return ResolveIP(ctx, a.domain())
return ResolveIP(ctx, network, a.domain())
default:
panic("ResolveIP() called on zero value")
}
Expand All @@ -149,13 +152,15 @@ func (a Addr) ResolveIP(ctx context.Context) (netip.Addr, error) {
// ResolveIPPort returns the IP address itself or the resolved IP address of the domain name
// and the port number as a [netip.AddrPort].
//
// The network is only used for domain name resolution and must be one of "ip", "ip4" or "ip6".
//
// If the address is zero value, this method panics.
func (a Addr) ResolveIPPort(ctx context.Context) (netip.AddrPort, error) {
func (a Addr) ResolveIPPort(ctx context.Context, network string) (netip.AddrPort, error) {
switch a.af {
case addressFamilyNetip:
return a.ipPort(), nil
case addressFamilyDomain:
ip, err := ResolveIP(ctx, a.domain())
ip, err := ResolveIP(ctx, network, a.domain())
if err != nil {
return netip.AddrPort{}, err
}
Expand Down
12 changes: 6 additions & 6 deletions conn/addr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,37 +146,37 @@ func TestAddrIPPort(t *testing.T) {
func TestAddrResolveIP(t *testing.T) {
ctx := context.Background()

ip, err := addrIP.ResolveIP(ctx)
ip, err := addrIP.ResolveIP(ctx, "ip")
if err != nil {
t.Fatal(err)
}
if ip != addrIPAddr {
t.Errorf("addrIP.ResolveIP() returned %s, expected %s.", ip, addrIPAddr)
}

ip, err = addrDomain.ResolveIP(ctx)
ip, err = addrDomain.ResolveIP(ctx, "ip")
if err != nil {
t.Fatal(err)
}
if !ip.IsValid() {
t.Error("addrDomain.ResolveIP() returned invalid IP address.")
}

assertPanic(t, func() { addrZero.ResolveIP(ctx) })
assertPanic(t, func() { addrZero.ResolveIP(ctx, "ip") })
}

func TestAddrResolveIPPort(t *testing.T) {
ctx := context.Background()

ipPort, err := addrIP.ResolveIPPort(ctx)
ipPort, err := addrIP.ResolveIPPort(ctx, "ip")
if err != nil {
t.Fatal(err)
}
if ipPort != addrIPAddrPort {
t.Errorf("addrIP.ResolveIPPort() returned %s, expected %s.", ipPort, addrIPAddrPort)
}

ipPort, err = addrDomain.ResolveIPPort(ctx)
ipPort, err = addrDomain.ResolveIPPort(ctx, "ip")
if err != nil {
t.Fatal(err)
}
Expand All @@ -187,7 +187,7 @@ func TestAddrResolveIPPort(t *testing.T) {
t.Errorf("addrDomain.ResolveIPPort(false) returned %s, expected port %d.", ipPort, addrDomainPort)
}

assertPanic(t, func() { addrZero.ResolveIPPort(ctx) })
assertPanic(t, func() { addrZero.ResolveIPPort(ctx, "ip") })
}

func TestAddrHost(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions docs/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"proxyPSK": "sAe5RvzLJ3Q0Ll88QRM1N01dYk83Q4y0rXMP1i4rDmI=",
"proxyFwmark": 0,
"proxyTrafficClass": 0,
"wgEndpointNetwork": "",
"wgEndpoint": "[::1]:20221",
"wgConnListenNetwork": "",
"wgConnListenAddress": "",
Expand All @@ -27,6 +28,7 @@
"wgListen": ":20222",
"wgFwmark": 0,
"wgTrafficClass": 0,
"proxyEndpointNetwork": "",
"proxyEndpoint": "[2001:db8:1f74:3c86:aef9:a75:5d2a:425e]:20220",
"proxyConnListenNetwork": "",
"proxyConnListenAddress": "",
Expand Down
22 changes: 17 additions & 5 deletions service/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ type ClientConfig struct {
WgListenAddress string `json:"wgListen"`
WgFwmark int `json:"wgFwmark"`
WgTrafficClass int `json:"wgTrafficClass"`
ProxyEndpoint conn.Addr `json:"proxyEndpoint"`
ProxyEndpointNetwork string `json:"proxyEndpointNetwork"`
ProxyEndpointAddress conn.Addr `json:"proxyEndpoint"`
ProxyConnListenNetwork string `json:"proxyConnListenNetwork"`
ProxyConnListenAddress string `json:"proxyConnListenAddress"`
ProxyMode string `json:"proxyMode"`
Expand Down Expand Up @@ -83,6 +84,7 @@ type client struct {
maxProxyPacketSizev6 int
wgTunnelMTU int
wgTunnelMTUv6 int
proxyNetwork string
proxyAddr conn.Addr
handler packet.Handler
logger *zap.Logger
Expand Down Expand Up @@ -114,6 +116,15 @@ func (cc *ClientConfig) Client(logger *zap.Logger, listenConfigCache conn.Listen
return nil, fmt.Errorf("invalid wgListenNetwork: %s", cc.WgListenNetwork)
}

// Check ProxyEndpointNetwork.
switch cc.ProxyEndpointNetwork {
case "":
cc.ProxyEndpointNetwork = "ip"
case "ip", "ip4", "ip6":
default:
return nil, fmt.Errorf("invalid proxyEndpointNetwork: %s", cc.ProxyEndpointNetwork)
}

// Check ProxyConnListenNetwork.
switch cc.ProxyConnListenNetwork {
case "":
Expand Down Expand Up @@ -141,8 +152,8 @@ func (cc *ClientConfig) Client(logger *zap.Logger, listenConfigCache conn.Listen
wgTunnelMTUv6 := getWgTunnelMTUForHandler(handler, maxProxyPacketSizev6)

// Use IPv6 values if the proxy endpoint is an IPv6 address.
if cc.ProxyEndpoint.IsIP() {
if ip := cc.ProxyEndpoint.IP(); !ip.Is4() && !ip.Is4In6() {
if cc.ProxyEndpointAddress.IsIP() {
if ip := cc.ProxyEndpointAddress.IP(); !ip.Is4() && !ip.Is4In6() {
maxProxyPacketSize = maxProxyPacketSizev6
wgTunnelMTU = wgTunnelMTUv6
}
Expand All @@ -161,7 +172,8 @@ func (cc *ClientConfig) Client(logger *zap.Logger, listenConfigCache conn.Listen
maxProxyPacketSizev6: maxProxyPacketSizev6,
wgTunnelMTU: wgTunnelMTU,
wgTunnelMTUv6: wgTunnelMTUv6,
proxyAddr: cc.ProxyEndpoint,
proxyNetwork: cc.ProxyEndpointNetwork,
proxyAddr: cc.ProxyEndpointAddress,
handler: handler,
logger: logger,
wgConnListenConfig: listenConfigCache.Get(conn.ListenerSocketOptions{
Expand Down Expand Up @@ -342,7 +354,7 @@ func (c *client) recvFromWgConnGeneric(ctx context.Context, wgConn *net.UDPConn)
c.wg.Done()
}()

proxyAddrPort, err := c.proxyAddr.ResolveIPPort(ctx)
proxyAddrPort, err := c.proxyAddr.ResolveIPPort(ctx, c.proxyNetwork)
if err != nil {
c.logger.Warn("Failed to resolve proxy address for new session",
zap.String("client", c.name),
Expand Down
2 changes: 1 addition & 1 deletion service/client_mmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func (c *client) recvFromWgConnRecvmmsg(ctx context.Context, wgConn *conn.MmsgRC
c.wg.Done()
}()

proxyAddrPort, err := c.proxyAddr.ResolveIPPort(ctx)
proxyAddrPort, err := c.proxyAddr.ResolveIPPort(ctx, c.proxyNetwork)
if err != nil {
c.logger.Warn("Failed to resolve proxy address for new session",
zap.String("client", c.name),
Expand Down
60 changes: 30 additions & 30 deletions service/client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func testClientServerHandshake(t *testing.T, ctx context.Context, serverConfig S
if err != nil {
t.Fatal(err)
}
serverConn, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpoint.String())
serverConn, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpointAddress.String())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -110,17 +110,17 @@ func TestClientServerHandshakeZeroOverhead(t *testing.T) {
ProxyListenAddress: ":20220",
ProxyMode: "zero-overhead",
ProxyPSK: psk,
WgEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20221)),
WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20221)),
MTU: 1500,
}

clientConfig := ClientConfig{
Name: "wg0",
WgListenAddress: ":20222",
ProxyEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20220)),
ProxyMode: "zero-overhead",
ProxyPSK: psk,
MTU: 1500,
Name: "wg0",
WgListenAddress: ":20222",
ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20220)),
ProxyMode: "zero-overhead",
ProxyPSK: psk,
MTU: 1500,
}

testClientServerHandshake(t, context.Background(), serverConfig, clientConfig)
Expand All @@ -134,17 +134,17 @@ func TestClientServerHandshakeParanoid(t *testing.T) {
ProxyListenAddress: ":20223",
ProxyMode: "paranoid",
ProxyPSK: psk,
WgEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20224)),
WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20224)),
MTU: 1500,
}

clientConfig := ClientConfig{
Name: "wg0",
WgListenAddress: ":20225",
ProxyEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20223)),
ProxyMode: "paranoid",
ProxyPSK: psk,
MTU: 1500,
Name: "wg0",
WgListenAddress: ":20225",
ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20223)),
ProxyMode: "paranoid",
ProxyPSK: psk,
MTU: 1500,
}

testClientServerHandshake(t, context.Background(), serverConfig, clientConfig)
Expand Down Expand Up @@ -185,7 +185,7 @@ func testClientServerDataPackets(t *testing.T, ctx context.Context, serverConfig
if err != nil {
t.Fatal(err)
}
serverConn, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpoint.String())
serverConn, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpointAddress.String())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -245,17 +245,17 @@ func TestClientServerDataPacketsZeroOverhead(t *testing.T) {
ProxyListenAddress: ":20230",
ProxyMode: "zero-overhead",
ProxyPSK: psk,
WgEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20231)),
WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20231)),
MTU: 1500,
}

clientConfig := ClientConfig{
Name: "wg0",
WgListenAddress: ":20232",
ProxyEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20230)),
ProxyMode: "zero-overhead",
ProxyPSK: psk,
MTU: 1500,
Name: "wg0",
WgListenAddress: ":20232",
ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20230)),
ProxyMode: "zero-overhead",
ProxyPSK: psk,
MTU: 1500,
}

testClientServerDataPackets(t, context.Background(), serverConfig, clientConfig)
Expand All @@ -269,17 +269,17 @@ func TestClientServerDataPacketsParanoid(t *testing.T) {
ProxyListenAddress: ":20233",
ProxyMode: "paranoid",
ProxyPSK: psk,
WgEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20234)),
WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20234)),
MTU: 1500,
}

clientConfig := ClientConfig{
Name: "wg0",
WgListenAddress: ":20235",
ProxyEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20233)),
ProxyMode: "paranoid",
ProxyPSK: psk,
MTU: 1500,
Name: "wg0",
WgListenAddress: ":20235",
ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20233)),
ProxyMode: "paranoid",
ProxyPSK: psk,
MTU: 1500,
}

testClientServerDataPackets(t, context.Background(), serverConfig, clientConfig)
Expand Down
18 changes: 15 additions & 3 deletions service/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ type ServerConfig struct {
ProxyPSK []byte `json:"proxyPSK"`
ProxyFwmark int `json:"proxyFwmark"`
ProxyTrafficClass int `json:"proxyTrafficClass"`
WgEndpoint conn.Addr `json:"wgEndpoint"`
WgEndpointNetwork string `json:"wgEndpointNetwork"`
WgEndpointAddress conn.Addr `json:"wgEndpoint"`
WgConnListenNetwork string `json:"wgConnListenNetwork"`
WgConnListenAddress string `json:"wgConnListenAddress"`
WgFwmark int `json:"wgFwmark"`
Expand Down Expand Up @@ -83,6 +84,7 @@ type server struct {
maxProxyPacketSizev6 int
wgTunnelMTUv4 int
wgTunnelMTUv6 int
wgNetwork string
wgAddr conn.Addr
handler packet.Handler
logger *zap.Logger
Expand Down Expand Up @@ -114,6 +116,15 @@ func (sc *ServerConfig) Server(logger *zap.Logger, listenConfigCache conn.Listen
return nil, fmt.Errorf("invalid proxyListenNetwork: %s", sc.ProxyListenNetwork)
}

// Check WgEndpointNetwork.
switch sc.WgEndpointNetwork {
case "":
sc.WgEndpointNetwork = "ip"
case "ip", "ip4", "ip6":
default:
return nil, fmt.Errorf("invalid wgEndpointNetwork: %s", sc.WgEndpointNetwork)
}

// Check WgConnListenNetwork.
switch sc.WgConnListenNetwork {
case "":
Expand Down Expand Up @@ -153,7 +164,8 @@ func (sc *ServerConfig) Server(logger *zap.Logger, listenConfigCache conn.Listen
maxProxyPacketSizev6: maxProxyPacketSizev6,
wgTunnelMTUv4: wgTunnelMTUv4,
wgTunnelMTUv6: wgTunnelMTUv6,
wgAddr: sc.WgEndpoint,
wgNetwork: sc.WgEndpointNetwork,
wgAddr: sc.WgEndpointAddress,
handler: handler,
logger: logger,
proxyConnListenConfig: listenConfigCache.Get(conn.ListenerSocketOptions{
Expand Down Expand Up @@ -332,7 +344,7 @@ func (s *server) recvFromProxyConnGeneric(ctx context.Context, proxyConn *net.UD
s.wg.Done()
}()

wgAddrPort, err := s.wgAddr.ResolveIPPort(ctx)
wgAddrPort, err := s.wgAddr.ResolveIPPort(ctx, s.wgNetwork)
if err != nil {
s.logger.Warn("Failed to resolve wg address for new session",
zap.String("server", s.name),
Expand Down
2 changes: 1 addition & 1 deletion service/server_mmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (s *server) recvFromProxyConnRecvmmsg(ctx context.Context, proxyConn *conn.
s.wg.Done()
}()

wgAddrPort, err := s.wgAddr.ResolveIPPort(ctx)
wgAddrPort, err := s.wgAddr.ResolveIPPort(ctx, s.wgNetwork)
if err != nil {
s.logger.Warn("Failed to resolve wgAddr",
zap.String("server", s.name),
Expand Down

0 comments on commit 5bbf87a

Please sign in to comment.