From e73f3804ec4f5b963a8d494e5486bcd8a85dd960 Mon Sep 17 00:00:00 2001 From: Ben Schwartz Date: Wed, 5 Jan 2022 09:04:54 -0500 Subject: [PATCH] Use the IP PacketConn to specify the local proxy IP This allows the proxy to determine the destination address of incoming UDP packets and specify the source address of outgoing UDP packets. It requires duplicating the UDP service because it uses the x/net/ipv[4,6] packages, which require us to know the address family of each incoming packet before it is received. --- go.mod | 1 + go.sum | 8 ++++ integration_test/integration_test.go | 8 ++-- net/net.go | 66 ++++++++++++++++++++++++++++ server.go | 37 +++++++++++----- service/udp.go | 42 +++++++++++------- service/udp_test.go | 11 ++--- 7 files changed, 138 insertions(+), 35 deletions(-) diff --git a/go.mod b/go.mod index b22e84ff..cb207d0f 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/shadowsocks/go-shadowsocks2 v0.1.4-0.20201002022019-75d43273f5a5 github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897 + golang.org/x/net v0.0.0-20211216030914-fe4d6282115f // indirect gopkg.in/yaml.v2 v2.3.0 ) diff --git a/go.sum b/go.sum index e0d24d73..4f597537 100644 --- a/go.sum +++ b/go.sum @@ -104,6 +104,8 @@ golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20211216030914-fe4d6282115f h1:hEYJvxw1lSnWIl8X9ofsYMklzaDs90JI2az5YMd4fPM= +golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -121,7 +123,13 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 h1:AvbQYmiaaaza3cW3QXRyPo5kYgpFIzOAfeAAN7m3qQ4= golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/integration_test/integration_test.go b/integration_test/integration_test.go index de674e7e..d4955d52 100644 --- a/integration_test/integration_test.go +++ b/integration_test/integration_test.go @@ -72,7 +72,7 @@ func startTCPEchoServer(t testing.TB) (*net.TCPListener, *sync.WaitGroup) { } func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) { - conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { t.Fatalf("Proxy ListenUDP failed: %v", err) } @@ -256,7 +256,7 @@ func (m *fakeUDPMetrics) RemoveUDPNatEntry() { func TestUDPEcho(t *testing.T) { echoConn, echoRunning := startUDPEchoServer(t) - proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + proxyConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { t.Fatalf("ListenTCP failed: %v", err) } @@ -496,7 +496,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { func BenchmarkUDPEcho(b *testing.B) { echoConn, echoRunning := startUDPEchoServer(b) - proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + proxyConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { b.Fatalf("ListenTCP failed: %v", err) } @@ -544,7 +544,7 @@ func BenchmarkUDPEcho(b *testing.B) { func BenchmarkUDPManyKeys(b *testing.B) { echoConn, echoRunning := startUDPEchoServer(b) - proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + proxyConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { b.Fatalf("ListenTCP failed: %v", err) } diff --git a/net/net.go b/net/net.go index 1965245d..7a76091c 100644 --- a/net/net.go +++ b/net/net.go @@ -1,8 +1,14 @@ package net import ( + "errors" + "fmt" "io" "net" + "runtime" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) // DuplexConn is a net.Conn that allows for closing only the reader or writer end of @@ -97,3 +103,63 @@ type ConnectionError struct { func NewConnectionError(status, message string, cause error) *ConnectionError { return &ConnectionError{Status: status, Message: message, Cause: cause} } + +// ReadFromWithDst reads one packet from `conn` into `b` and returns the number +// of bytes read, the source address, and the destination IP address. It enables +// recovery of the destination IP, which is otherwise lost for UDP connections +// that are bound to `0.0.0.0` or `::`. +func ReadFromWithDst(conn net.PacketConn, b []byte) (n int, src *net.UDPAddr, dst net.IP, err error) { + var tmpSrc net.Addr + if conn.LocalAddr().Network() == "udp4" { + ipv4Conn := ipv4.NewPacketConn(conn) + if err = ipv4Conn.SetControlMessage(ipv4.FlagDst, true); err != nil { + return + } + var cm *ipv4.ControlMessage + if n, cm, tmpSrc, err = ipv4Conn.ReadFrom(b); err != nil { + return + } + if cm != nil { + dst = cm.Dst + } else if runtime.GOOS != "windows" { + err = errors.New("control data is missing") + return + } + } else if conn.LocalAddr().Network() == "udp6" { + ipv6Conn := ipv6.NewPacketConn(conn) + if err = ipv6Conn.SetControlMessage(ipv6.FlagDst, true); err != nil { + return + } + var cm *ipv6.ControlMessage + if n, cm, tmpSrc, err = ipv6Conn.ReadFrom(b); err != nil { + return + } + if cm != nil { + dst = cm.Dst + } else if runtime.GOOS != "windows" { + err = errors.New("control data is missing") + return + } + } else { + err = fmt.Errorf("unsupported network: %s", conn.LocalAddr().Network()) + return + } + src = tmpSrc.(*net.UDPAddr) + return +} + +// WriteToWithSrc sends `b` to `dst` on `conn` from the specified source IP. +// This can be useful when the system has multiple IP addresses of the same family. +// Similar functionality can be achieved by binding a new UDP socket to a specific local address, +// but that might run into problems if the port is already bound by an existing socket. +func WriteToWithSrc(conn net.PacketConn, b []byte, src net.IP, dst *net.UDPAddr) (int, error) { + if conn.LocalAddr().Network() == "udp4" { + cm := &ipv4.ControlMessage{Src: src} + return ipv4.NewPacketConn(conn).WriteTo(b, cm, dst) + } else if conn.LocalAddr().Network() == "udp6" { + cm := &ipv6.ControlMessage{Src: src} + return ipv6.NewPacketConn(conn).WriteTo(b, cm, dst) + } else { + return 0, fmt.Errorf("unsupported network: %s", conn.LocalAddr().Network()) + } +} \ No newline at end of file diff --git a/server.go b/server.go index c3c04121..04ecd78e 100644 --- a/server.go +++ b/server.go @@ -63,7 +63,8 @@ func init() { type ssPort struct { tcpService service.TCPService - udpService service.UDPService + udp4Service service.UDPService + udp6Service service.UDPService cipherList service.CipherList } @@ -79,18 +80,25 @@ func (s *SSServer) startPort(portNum int) error { if err != nil { return fmt.Errorf("Failed to start TCP on port %v: %v", portNum, err) } - packetConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: portNum}) - if err != nil { - return fmt.Errorf("Failed to start UDP on port %v: %v", portNum, err) + udp4Conn, udp4err := net.ListenUDP("udp4", &net.UDPAddr{Port: portNum}) + udp6Conn, udp6err := net.ListenUDP("udp6", &net.UDPAddr{Port: portNum}) + if udp4err != nil && udp6err != nil { + return fmt.Errorf("Failed to start UDP on port %v: %v", portNum, udp4err) } logger.Infof("Listening TCP and UDP on port %v", portNum) port := &ssPort{cipherList: service.NewCipherList()} + s.ports[portNum] = port // TODO: Register initial data metrics at zero. port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout) - port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m) - s.ports[portNum] = port go port.tcpService.Serve(listener) - go port.udpService.Serve(packetConn) + if udp4err == nil { + port.udp4Service = service.NewUDPService(s.natTimeout, port.cipherList, s.m) + go port.udp4Service.Serve(udp4Conn) + } + if udp6err == nil { + port.udp6Service = service.NewUDPService(s.natTimeout, port.cipherList, s.m) + go port.udp6Service.Serve(udp6Conn) + } return nil } @@ -100,13 +108,22 @@ func (s *SSServer) removePort(portNum int) error { return fmt.Errorf("Port %v doesn't exist", portNum) } tcpErr := port.tcpService.Stop() - udpErr := port.udpService.Stop() + var udp4Err, udp6Err error + if port.udp4Service != nil { + udp4Err = port.udp4Service.Stop() + } + if port.udp6Service != nil { + udp6Err = port.udp6Service.Stop() + } delete(s.ports, portNum) if tcpErr != nil { return fmt.Errorf("Failed to close listener on %v: %v", portNum, tcpErr) } - if udpErr != nil { - return fmt.Errorf("Failed to close packetConn on %v: %v", portNum, udpErr) + if udp4Err != nil { + return fmt.Errorf("Failed to stop IPv4 UDP service on %v: %v", portNum, udp4Err) + } + if udp4Err != nil { + return fmt.Errorf("Failed to stop IPv6 UDP service on %v: %v", portNum, udp6Err) } logger.Infof("Stopped TCP and UDP on port %v", portNum) return nil diff --git a/service/udp.go b/service/udp.go index 42160497..0c9d9d65 100644 --- a/service/udp.go +++ b/service/udp.go @@ -135,7 +135,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { }() // Attempt to read an upstream packet. - clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) + clientProxyBytes, clientAddr, proxyIP, err := onet.ReadFromWithDst(clientConn, cipherBuf) if err != nil { s.mu.RLock() stopped = s.stopped @@ -171,7 +171,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { cipherData := cipherBuf[:clientProxyBytes] var payload []byte var tgtUDPAddr *net.UDPAddr - targetConn := nm.Get(clientAddr.String()) + targetConn := nm.Get(clientAddr, proxyIP) if targetConn == nil { var locErr error clientLocation, locErr = s.m.GetLocation(clientAddr) @@ -180,7 +180,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { } debugUDPAddr(clientAddr, "Got location \"%s\"", clientLocation) - ip := clientAddr.(*net.UDPAddr).IP + ip := clientAddr.IP var textData []byte var cipher *ss.Cipher unpackStart := time.Now() @@ -200,7 +200,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error { if err != nil { return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) } - targetConn = nm.Add(clientAddr, clientConn, cipher, udpConn, clientLocation, keyID) + targetConn = nm.Add(clientAddr, proxyIP, clientConn, cipher, udpConn, clientLocation, keyID) } else { clientLocation = targetConn.clientLocation @@ -335,10 +335,19 @@ func (c *natconn) ReadFrom(buf []byte) (int, net.Addr, error) { return n, addr, err } +type natkey struct { + clientAddr string // TODO: Use netip.AddrPort + proxyIP string // TODO: Use netip.Addr +} + +func makeNATKey(clientAddr *net.UDPAddr, proxyIP net.IP) natkey { + return natkey{clientAddr.String(), proxyIP.String()} +} + // Packet NAT table type natmap struct { sync.RWMutex - keyConn map[string]*natconn + keyConn map[natkey]*natconn timeout time.Duration metrics metrics.ShadowsocksMetrics running *sync.WaitGroup @@ -346,18 +355,18 @@ type natmap struct { func newNATmap(timeout time.Duration, sm metrics.ShadowsocksMetrics, running *sync.WaitGroup) *natmap { m := &natmap{metrics: sm, running: running} - m.keyConn = make(map[string]*natconn) + m.keyConn = make(map[natkey]*natconn) m.timeout = timeout return m } -func (m *natmap) Get(key string) *natconn { +func (m *natmap) Get(clientAddr *net.UDPAddr, proxyIP net.IP) *natconn { m.RLock() defer m.RUnlock() - return m.keyConn[key] + return m.keyConn[makeNATKey(clientAddr, proxyIP)] } -func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn { +func (m *natmap) set(key natkey, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn { entry := &natconn{ PacketConn: pc, cipher: cipher, @@ -373,7 +382,7 @@ func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, cl return entry } -func (m *natmap) del(key string) net.PacketConn { +func (m *natmap) del(key natkey) net.PacketConn { m.Lock() defer m.Unlock() @@ -385,15 +394,16 @@ func (m *natmap) del(key string) net.PacketConn { return nil } -func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn { - entry := m.set(clientAddr.String(), targetConn, cipher, keyID, clientLocation) +func (m *natmap) Add(clientAddr *net.UDPAddr, proxyIP net.IP, clientConn net.PacketConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn { + key := makeNATKey(clientAddr, proxyIP) + entry := m.set(key, targetConn, cipher, keyID, clientLocation) m.metrics.AddUDPNatEntry() m.running.Add(1) go func() { - timedCopy(clientAddr, clientConn, entry, keyID, m.metrics) + timedCopy(clientAddr, proxyIP, clientConn, entry, keyID, m.metrics) m.metrics.RemoveUDPNatEntry() - if pc := m.del(clientAddr.String()); pc != nil { + if pc := m.del(key); pc != nil { pc.Close() } m.running.Done() @@ -420,7 +430,7 @@ func (m *natmap) Close() error { var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout -func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, +func timedCopy(clientAddr *net.UDPAddr, proxyIP net.IP, clientConn net.PacketConn, targetConn *natconn, keyID string, sm metrics.ShadowsocksMetrics) { // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] @@ -475,7 +485,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco if err != nil { return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err) } - proxyClientBytes, err = clientConn.WriteTo(buf, clientAddr) + proxyClientBytes, err = onet.WriteToWithSrc(clientConn, buf, proxyIP, clientAddr) if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err) } diff --git a/service/udp_test.go b/service/udp_test.go index 12139fdc..51675361 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -35,6 +35,7 @@ const timeout = 5 * time.Minute var clientAddr = net.UDPAddr{IP: []byte{192, 0, 2, 1}, Port: 12345} var targetAddr = net.UDPAddr{IP: []byte{192, 0, 2, 2}, Port: 54321} var dnsAddr = net.UDPAddr{IP: []byte{192, 0, 2, 3}, Port: 53} +var proxyIP net.IP = []byte{192,0,2, 4} var natCipher *ss.Cipher func init() { @@ -203,7 +204,7 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) { func TestNATEmpty(t *testing.T) { nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) - if nat.Get("foo") != nil { + if nat.Get(&clientAddr, proxyIP) != nil { t.Error("Expected nil value from empty NAT map") } } @@ -212,8 +213,8 @@ func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) clientConn := makePacketConn() targetConn := makePacketConn() - nat.Add(&clientAddr, clientConn, natCipher, targetConn, "ZZ", "key id") - entry := nat.Get(clientAddr.String()) + nat.Add(&clientAddr, proxyIP, clientConn, natCipher, targetConn, "ZZ", "key id") + entry := nat.Get(&clientAddr, proxyIP) return clientConn, targetConn, entry } @@ -478,7 +479,7 @@ func TestUDPDoubleServe(t *testing.T) { c := make(chan error) for i := 0; i < 2; i++ { - clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + clientConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { t.Fatalf("ListenUDP failed: %v", err) } @@ -513,7 +514,7 @@ func TestUDPEarlyStop(t *testing.T) { if err := s.Stop(); err != nil { t.Error(err) } - clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + clientConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) if err != nil { t.Fatalf("ListenUDP failed: %v", err) }