From 5bbf87a8c65f586e6cd6e477f110f708b33fca75 Mon Sep 17 00:00:00 2001 From: database64128 Date: Sun, 28 Jan 2024 16:18:57 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8D=A0=20conn,=20service:=20support=20cus?= =?UTF-8?q?tom=20"network"=20for=20endpoint=20address=20resolution?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- conn/addr.go | 21 +++++++----- conn/addr_test.go | 12 +++---- docs/config.json | 2 ++ service/client.go | 22 ++++++++++--- service/client_mmsg.go | 2 +- service/client_server_test.go | 60 +++++++++++++++++------------------ service/server.go | 18 +++++++++-- service/server_mmsg.go | 2 +- 8 files changed, 85 insertions(+), 54 deletions(-) diff --git a/conn/addr.go b/conn/addr.go index 5708a30..e93db1a 100644 --- a/conn/addr.go +++ b/conn/addr.go @@ -119,13 +119,14 @@ func (a Addr) IPPort() netip.AddrPort { // ResolveIP resolves a domain name string into an IP address. // +// The network must be one of "ip", "ip4" or "ip6". +// String representations of IP addresses are not supported. +// // This function always returns the first IP address returned by the resolver, // because the resolver takes care of sorting the IP addresses by address family // availability and preference. -// -// String representations of IP addresses are not supported. -func ResolveIP(ctx context.Context, host string) (netip.Addr, error) { - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", host) +func ResolveIP(ctx context.Context, network, host string) (netip.Addr, error) { + ips, err := net.DefaultResolver.LookupNetIP(ctx, network, host) if err != nil { return netip.Addr{}, err } @@ -134,13 +135,15 @@ func ResolveIP(ctx context.Context, host string) (netip.Addr, error) { // ResolveIP returns the IP address itself or the resolved IP address of the domain name. // +// The network is only used for domain name resolution and must be one of "ip", "ip4" or "ip6". +// // If the address is zero value, this method panics. -func (a Addr) ResolveIP(ctx context.Context) (netip.Addr, error) { +func (a Addr) ResolveIP(ctx context.Context, network string) (netip.Addr, error) { switch a.af { case addressFamilyNetip: return a.ip(), nil case addressFamilyDomain: - return ResolveIP(ctx, a.domain()) + return ResolveIP(ctx, network, a.domain()) default: panic("ResolveIP() called on zero value") } @@ -149,13 +152,15 @@ func (a Addr) ResolveIP(ctx context.Context) (netip.Addr, error) { // ResolveIPPort returns the IP address itself or the resolved IP address of the domain name // and the port number as a [netip.AddrPort]. // +// The network is only used for domain name resolution and must be one of "ip", "ip4" or "ip6". +// // If the address is zero value, this method panics. -func (a Addr) ResolveIPPort(ctx context.Context) (netip.AddrPort, error) { +func (a Addr) ResolveIPPort(ctx context.Context, network string) (netip.AddrPort, error) { switch a.af { case addressFamilyNetip: return a.ipPort(), nil case addressFamilyDomain: - ip, err := ResolveIP(ctx, a.domain()) + ip, err := ResolveIP(ctx, network, a.domain()) if err != nil { return netip.AddrPort{}, err } diff --git a/conn/addr_test.go b/conn/addr_test.go index 8829b8c..2249547 100644 --- a/conn/addr_test.go +++ b/conn/addr_test.go @@ -146,7 +146,7 @@ func TestAddrIPPort(t *testing.T) { func TestAddrResolveIP(t *testing.T) { ctx := context.Background() - ip, err := addrIP.ResolveIP(ctx) + ip, err := addrIP.ResolveIP(ctx, "ip") if err != nil { t.Fatal(err) } @@ -154,7 +154,7 @@ func TestAddrResolveIP(t *testing.T) { t.Errorf("addrIP.ResolveIP() returned %s, expected %s.", ip, addrIPAddr) } - ip, err = addrDomain.ResolveIP(ctx) + ip, err = addrDomain.ResolveIP(ctx, "ip") if err != nil { t.Fatal(err) } @@ -162,13 +162,13 @@ func TestAddrResolveIP(t *testing.T) { t.Error("addrDomain.ResolveIP() returned invalid IP address.") } - assertPanic(t, func() { addrZero.ResolveIP(ctx) }) + assertPanic(t, func() { addrZero.ResolveIP(ctx, "ip") }) } func TestAddrResolveIPPort(t *testing.T) { ctx := context.Background() - ipPort, err := addrIP.ResolveIPPort(ctx) + ipPort, err := addrIP.ResolveIPPort(ctx, "ip") if err != nil { t.Fatal(err) } @@ -176,7 +176,7 @@ func TestAddrResolveIPPort(t *testing.T) { t.Errorf("addrIP.ResolveIPPort() returned %s, expected %s.", ipPort, addrIPAddrPort) } - ipPort, err = addrDomain.ResolveIPPort(ctx) + ipPort, err = addrDomain.ResolveIPPort(ctx, "ip") if err != nil { t.Fatal(err) } @@ -187,7 +187,7 @@ func TestAddrResolveIPPort(t *testing.T) { t.Errorf("addrDomain.ResolveIPPort(false) returned %s, expected port %d.", ipPort, addrDomainPort) } - assertPanic(t, func() { addrZero.ResolveIPPort(ctx) }) + assertPanic(t, func() { addrZero.ResolveIPPort(ctx, "ip") }) } func TestAddrHost(t *testing.T) { diff --git a/docs/config.json b/docs/config.json index 5bce9af..fad746a 100644 --- a/docs/config.json +++ b/docs/config.json @@ -8,6 +8,7 @@ "proxyPSK": "sAe5RvzLJ3Q0Ll88QRM1N01dYk83Q4y0rXMP1i4rDmI=", "proxyFwmark": 0, "proxyTrafficClass": 0, + "wgEndpointNetwork": "", "wgEndpoint": "[::1]:20221", "wgConnListenNetwork": "", "wgConnListenAddress": "", @@ -27,6 +28,7 @@ "wgListen": ":20222", "wgFwmark": 0, "wgTrafficClass": 0, + "proxyEndpointNetwork": "", "proxyEndpoint": "[2001:db8:1f74:3c86:aef9:a75:5d2a:425e]:20220", "proxyConnListenNetwork": "", "proxyConnListenAddress": "", diff --git a/service/client.go b/service/client.go index 8bc909b..7721a94 100644 --- a/service/client.go +++ b/service/client.go @@ -26,7 +26,8 @@ type ClientConfig struct { WgListenAddress string `json:"wgListen"` WgFwmark int `json:"wgFwmark"` WgTrafficClass int `json:"wgTrafficClass"` - ProxyEndpoint conn.Addr `json:"proxyEndpoint"` + ProxyEndpointNetwork string `json:"proxyEndpointNetwork"` + ProxyEndpointAddress conn.Addr `json:"proxyEndpoint"` ProxyConnListenNetwork string `json:"proxyConnListenNetwork"` ProxyConnListenAddress string `json:"proxyConnListenAddress"` ProxyMode string `json:"proxyMode"` @@ -83,6 +84,7 @@ type client struct { maxProxyPacketSizev6 int wgTunnelMTU int wgTunnelMTUv6 int + proxyNetwork string proxyAddr conn.Addr handler packet.Handler logger *zap.Logger @@ -114,6 +116,15 @@ func (cc *ClientConfig) Client(logger *zap.Logger, listenConfigCache conn.Listen return nil, fmt.Errorf("invalid wgListenNetwork: %s", cc.WgListenNetwork) } + // Check ProxyEndpointNetwork. + switch cc.ProxyEndpointNetwork { + case "": + cc.ProxyEndpointNetwork = "ip" + case "ip", "ip4", "ip6": + default: + return nil, fmt.Errorf("invalid proxyEndpointNetwork: %s", cc.ProxyEndpointNetwork) + } + // Check ProxyConnListenNetwork. switch cc.ProxyConnListenNetwork { case "": @@ -141,8 +152,8 @@ func (cc *ClientConfig) Client(logger *zap.Logger, listenConfigCache conn.Listen wgTunnelMTUv6 := getWgTunnelMTUForHandler(handler, maxProxyPacketSizev6) // Use IPv6 values if the proxy endpoint is an IPv6 address. - if cc.ProxyEndpoint.IsIP() { - if ip := cc.ProxyEndpoint.IP(); !ip.Is4() && !ip.Is4In6() { + if cc.ProxyEndpointAddress.IsIP() { + if ip := cc.ProxyEndpointAddress.IP(); !ip.Is4() && !ip.Is4In6() { maxProxyPacketSize = maxProxyPacketSizev6 wgTunnelMTU = wgTunnelMTUv6 } @@ -161,7 +172,8 @@ func (cc *ClientConfig) Client(logger *zap.Logger, listenConfigCache conn.Listen maxProxyPacketSizev6: maxProxyPacketSizev6, wgTunnelMTU: wgTunnelMTU, wgTunnelMTUv6: wgTunnelMTUv6, - proxyAddr: cc.ProxyEndpoint, + proxyNetwork: cc.ProxyEndpointNetwork, + proxyAddr: cc.ProxyEndpointAddress, handler: handler, logger: logger, wgConnListenConfig: listenConfigCache.Get(conn.ListenerSocketOptions{ @@ -342,7 +354,7 @@ func (c *client) recvFromWgConnGeneric(ctx context.Context, wgConn *net.UDPConn) c.wg.Done() }() - proxyAddrPort, err := c.proxyAddr.ResolveIPPort(ctx) + proxyAddrPort, err := c.proxyAddr.ResolveIPPort(ctx, c.proxyNetwork) if err != nil { c.logger.Warn("Failed to resolve proxy address for new session", zap.String("client", c.name), diff --git a/service/client_mmsg.go b/service/client_mmsg.go index 6a7acab..b7a9537 100644 --- a/service/client_mmsg.go +++ b/service/client_mmsg.go @@ -242,7 +242,7 @@ func (c *client) recvFromWgConnRecvmmsg(ctx context.Context, wgConn *conn.MmsgRC c.wg.Done() }() - proxyAddrPort, err := c.proxyAddr.ResolveIPPort(ctx) + proxyAddrPort, err := c.proxyAddr.ResolveIPPort(ctx, c.proxyNetwork) if err != nil { c.logger.Warn("Failed to resolve proxy address for new session", zap.String("client", c.name), diff --git a/service/client_server_test.go b/service/client_server_test.go index d528d32..cda53a0 100644 --- a/service/client_server_test.go +++ b/service/client_server_test.go @@ -62,7 +62,7 @@ func testClientServerHandshake(t *testing.T, ctx context.Context, serverConfig S if err != nil { t.Fatal(err) } - serverConn, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpoint.String()) + serverConn, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpointAddress.String()) if err != nil { t.Fatal(err) } @@ -110,17 +110,17 @@ func TestClientServerHandshakeZeroOverhead(t *testing.T) { ProxyListenAddress: ":20220", ProxyMode: "zero-overhead", ProxyPSK: psk, - WgEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20221)), + WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20221)), MTU: 1500, } clientConfig := ClientConfig{ - Name: "wg0", - WgListenAddress: ":20222", - ProxyEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20220)), - ProxyMode: "zero-overhead", - ProxyPSK: psk, - MTU: 1500, + Name: "wg0", + WgListenAddress: ":20222", + ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20220)), + ProxyMode: "zero-overhead", + ProxyPSK: psk, + MTU: 1500, } testClientServerHandshake(t, context.Background(), serverConfig, clientConfig) @@ -134,17 +134,17 @@ func TestClientServerHandshakeParanoid(t *testing.T) { ProxyListenAddress: ":20223", ProxyMode: "paranoid", ProxyPSK: psk, - WgEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20224)), + WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20224)), MTU: 1500, } clientConfig := ClientConfig{ - Name: "wg0", - WgListenAddress: ":20225", - ProxyEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20223)), - ProxyMode: "paranoid", - ProxyPSK: psk, - MTU: 1500, + Name: "wg0", + WgListenAddress: ":20225", + ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20223)), + ProxyMode: "paranoid", + ProxyPSK: psk, + MTU: 1500, } testClientServerHandshake(t, context.Background(), serverConfig, clientConfig) @@ -185,7 +185,7 @@ func testClientServerDataPackets(t *testing.T, ctx context.Context, serverConfig if err != nil { t.Fatal(err) } - serverConn, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpoint.String()) + serverConn, err := conn.DefaultUDPClientListenConfig.ListenUDP(ctx, "udp", serverConfig.WgEndpointAddress.String()) if err != nil { t.Fatal(err) } @@ -245,17 +245,17 @@ func TestClientServerDataPacketsZeroOverhead(t *testing.T) { ProxyListenAddress: ":20230", ProxyMode: "zero-overhead", ProxyPSK: psk, - WgEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20231)), + WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20231)), MTU: 1500, } clientConfig := ClientConfig{ - Name: "wg0", - WgListenAddress: ":20232", - ProxyEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20230)), - ProxyMode: "zero-overhead", - ProxyPSK: psk, - MTU: 1500, + Name: "wg0", + WgListenAddress: ":20232", + ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20230)), + ProxyMode: "zero-overhead", + ProxyPSK: psk, + MTU: 1500, } testClientServerDataPackets(t, context.Background(), serverConfig, clientConfig) @@ -269,17 +269,17 @@ func TestClientServerDataPacketsParanoid(t *testing.T) { ProxyListenAddress: ":20233", ProxyMode: "paranoid", ProxyPSK: psk, - WgEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20234)), + WgEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20234)), MTU: 1500, } clientConfig := ClientConfig{ - Name: "wg0", - WgListenAddress: ":20235", - ProxyEndpoint: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20233)), - ProxyMode: "paranoid", - ProxyPSK: psk, - MTU: 1500, + Name: "wg0", + WgListenAddress: ":20235", + ProxyEndpointAddress: conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Loopback(), 20233)), + ProxyMode: "paranoid", + ProxyPSK: psk, + MTU: 1500, } testClientServerDataPackets(t, context.Background(), serverConfig, clientConfig) diff --git a/service/server.go b/service/server.go index 044add0..95abc2b 100644 --- a/service/server.go +++ b/service/server.go @@ -28,7 +28,8 @@ type ServerConfig struct { ProxyPSK []byte `json:"proxyPSK"` ProxyFwmark int `json:"proxyFwmark"` ProxyTrafficClass int `json:"proxyTrafficClass"` - WgEndpoint conn.Addr `json:"wgEndpoint"` + WgEndpointNetwork string `json:"wgEndpointNetwork"` + WgEndpointAddress conn.Addr `json:"wgEndpoint"` WgConnListenNetwork string `json:"wgConnListenNetwork"` WgConnListenAddress string `json:"wgConnListenAddress"` WgFwmark int `json:"wgFwmark"` @@ -83,6 +84,7 @@ type server struct { maxProxyPacketSizev6 int wgTunnelMTUv4 int wgTunnelMTUv6 int + wgNetwork string wgAddr conn.Addr handler packet.Handler logger *zap.Logger @@ -114,6 +116,15 @@ func (sc *ServerConfig) Server(logger *zap.Logger, listenConfigCache conn.Listen return nil, fmt.Errorf("invalid proxyListenNetwork: %s", sc.ProxyListenNetwork) } + // Check WgEndpointNetwork. + switch sc.WgEndpointNetwork { + case "": + sc.WgEndpointNetwork = "ip" + case "ip", "ip4", "ip6": + default: + return nil, fmt.Errorf("invalid wgEndpointNetwork: %s", sc.WgEndpointNetwork) + } + // Check WgConnListenNetwork. switch sc.WgConnListenNetwork { case "": @@ -153,7 +164,8 @@ func (sc *ServerConfig) Server(logger *zap.Logger, listenConfigCache conn.Listen maxProxyPacketSizev6: maxProxyPacketSizev6, wgTunnelMTUv4: wgTunnelMTUv4, wgTunnelMTUv6: wgTunnelMTUv6, - wgAddr: sc.WgEndpoint, + wgNetwork: sc.WgEndpointNetwork, + wgAddr: sc.WgEndpointAddress, handler: handler, logger: logger, proxyConnListenConfig: listenConfigCache.Get(conn.ListenerSocketOptions{ @@ -332,7 +344,7 @@ func (s *server) recvFromProxyConnGeneric(ctx context.Context, proxyConn *net.UD s.wg.Done() }() - wgAddrPort, err := s.wgAddr.ResolveIPPort(ctx) + wgAddrPort, err := s.wgAddr.ResolveIPPort(ctx, s.wgNetwork) if err != nil { s.logger.Warn("Failed to resolve wg address for new session", zap.String("server", s.name), diff --git a/service/server_mmsg.go b/service/server_mmsg.go index 788cd9f..ef4088e 100644 --- a/service/server_mmsg.go +++ b/service/server_mmsg.go @@ -240,7 +240,7 @@ func (s *server) recvFromProxyConnRecvmmsg(ctx context.Context, proxyConn *conn. s.wg.Done() }() - wgAddrPort, err := s.wgAddr.ResolveIPPort(ctx) + wgAddrPort, err := s.wgAddr.ResolveIPPort(ctx, s.wgNetwork) if err != nil { s.logger.Warn("Failed to resolve wgAddr", zap.String("server", s.name),