Skip to content

Commit

Permalink
Use the IP PacketConn to specify the local proxy IP
Browse files Browse the repository at this point in the history
This allows the proxy to determine the destination address of incoming
UDP packets and specify the source address of outgoing UDP packets.  It
requires duplicating the UDP service because it uses the x/net/ipv[4,6]
packages, which require us to know the address family of each incoming
packet before it is received.
  • Loading branch information
Ben Schwartz committed Jan 5, 2022
1 parent 77b265e commit e73f380
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 35 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/shadowsocks/go-shadowsocks2 v0.1.4-0.20201002022019-75d43273f5a5
github.com/stretchr/testify v1.6.1
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f // indirect
gopkg.in/yaml.v2 v2.3.0
)

Expand Down
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f h1:hEYJvxw1lSnWIl8X9ofsYMklzaDs90JI2az5YMd4fPM=
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand All @@ -121,7 +123,13 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 h1:AvbQYmiaaaza3cW3QXRyPo5kYgpFIzOAfeAAN7m3qQ4=
golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
Expand Down
8 changes: 4 additions & 4 deletions integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func startTCPEchoServer(t testing.TB) (*net.TCPListener, *sync.WaitGroup) {
}

func startUDPEchoServer(t testing.TB) (*net.UDPConn, *sync.WaitGroup) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
t.Fatalf("Proxy ListenUDP failed: %v", err)
}
Expand Down Expand Up @@ -256,7 +256,7 @@ func (m *fakeUDPMetrics) RemoveUDPNatEntry() {
func TestUDPEcho(t *testing.T) {
echoConn, echoRunning := startUDPEchoServer(t)

proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
proxyConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
t.Fatalf("ListenTCP failed: %v", err)
}
Expand Down Expand Up @@ -496,7 +496,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) {
func BenchmarkUDPEcho(b *testing.B) {
echoConn, echoRunning := startUDPEchoServer(b)

proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
proxyConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
b.Fatalf("ListenTCP failed: %v", err)
}
Expand Down Expand Up @@ -544,7 +544,7 @@ func BenchmarkUDPEcho(b *testing.B) {
func BenchmarkUDPManyKeys(b *testing.B) {
echoConn, echoRunning := startUDPEchoServer(b)

proxyConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
proxyConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
b.Fatalf("ListenTCP failed: %v", err)
}
Expand Down
66 changes: 66 additions & 0 deletions net/net.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package net

import (
"errors"
"fmt"
"io"
"net"
"runtime"

"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)

// DuplexConn is a net.Conn that allows for closing only the reader or writer end of
Expand Down Expand Up @@ -97,3 +103,63 @@ type ConnectionError struct {
func NewConnectionError(status, message string, cause error) *ConnectionError {
return &ConnectionError{Status: status, Message: message, Cause: cause}
}

// ReadFromWithDst reads one packet from `conn` into `b` and returns the number
// of bytes read, the source address, and the destination IP address. It enables
// recovery of the destination IP, which is otherwise lost for UDP connections
// that are bound to `0.0.0.0` or `::`.
func ReadFromWithDst(conn net.PacketConn, b []byte) (n int, src *net.UDPAddr, dst net.IP, err error) {
var tmpSrc net.Addr
if conn.LocalAddr().Network() == "udp4" {
ipv4Conn := ipv4.NewPacketConn(conn)
if err = ipv4Conn.SetControlMessage(ipv4.FlagDst, true); err != nil {
return
}
var cm *ipv4.ControlMessage
if n, cm, tmpSrc, err = ipv4Conn.ReadFrom(b); err != nil {
return
}
if cm != nil {
dst = cm.Dst
} else if runtime.GOOS != "windows" {
err = errors.New("control data is missing")
return
}
} else if conn.LocalAddr().Network() == "udp6" {
ipv6Conn := ipv6.NewPacketConn(conn)
if err = ipv6Conn.SetControlMessage(ipv6.FlagDst, true); err != nil {
return
}
var cm *ipv6.ControlMessage
if n, cm, tmpSrc, err = ipv6Conn.ReadFrom(b); err != nil {
return
}
if cm != nil {
dst = cm.Dst
} else if runtime.GOOS != "windows" {
err = errors.New("control data is missing")
return
}
} else {
err = fmt.Errorf("unsupported network: %s", conn.LocalAddr().Network())
return
}
src = tmpSrc.(*net.UDPAddr)
return
}

// WriteToWithSrc sends `b` to `dst` on `conn` from the specified source IP.
// This can be useful when the system has multiple IP addresses of the same family.
// Similar functionality can be achieved by binding a new UDP socket to a specific local address,
// but that might run into problems if the port is already bound by an existing socket.
func WriteToWithSrc(conn net.PacketConn, b []byte, src net.IP, dst *net.UDPAddr) (int, error) {
if conn.LocalAddr().Network() == "udp4" {
cm := &ipv4.ControlMessage{Src: src}
return ipv4.NewPacketConn(conn).WriteTo(b, cm, dst)
} else if conn.LocalAddr().Network() == "udp6" {
cm := &ipv6.ControlMessage{Src: src}
return ipv6.NewPacketConn(conn).WriteTo(b, cm, dst)
} else {
return 0, fmt.Errorf("unsupported network: %s", conn.LocalAddr().Network())
}
}
37 changes: 27 additions & 10 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ func init() {

type ssPort struct {
tcpService service.TCPService
udpService service.UDPService
udp4Service service.UDPService
udp6Service service.UDPService
cipherList service.CipherList
}

Expand All @@ -79,18 +80,25 @@ func (s *SSServer) startPort(portNum int) error {
if err != nil {
return fmt.Errorf("Failed to start TCP on port %v: %v", portNum, err)
}
packetConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: portNum})
if err != nil {
return fmt.Errorf("Failed to start UDP on port %v: %v", portNum, err)
udp4Conn, udp4err := net.ListenUDP("udp4", &net.UDPAddr{Port: portNum})
udp6Conn, udp6err := net.ListenUDP("udp6", &net.UDPAddr{Port: portNum})
if udp4err != nil && udp6err != nil {
return fmt.Errorf("Failed to start UDP on port %v: %v", portNum, udp4err)
}
logger.Infof("Listening TCP and UDP on port %v", portNum)
port := &ssPort{cipherList: service.NewCipherList()}
s.ports[portNum] = port
// TODO: Register initial data metrics at zero.
port.tcpService = service.NewTCPService(port.cipherList, &s.replayCache, s.m, tcpReadTimeout)
port.udpService = service.NewUDPService(s.natTimeout, port.cipherList, s.m)
s.ports[portNum] = port
go port.tcpService.Serve(listener)
go port.udpService.Serve(packetConn)
if udp4err == nil {
port.udp4Service = service.NewUDPService(s.natTimeout, port.cipherList, s.m)
go port.udp4Service.Serve(udp4Conn)
}
if udp6err == nil {
port.udp6Service = service.NewUDPService(s.natTimeout, port.cipherList, s.m)
go port.udp6Service.Serve(udp6Conn)
}
return nil
}

Expand All @@ -100,13 +108,22 @@ func (s *SSServer) removePort(portNum int) error {
return fmt.Errorf("Port %v doesn't exist", portNum)
}
tcpErr := port.tcpService.Stop()
udpErr := port.udpService.Stop()
var udp4Err, udp6Err error
if port.udp4Service != nil {
udp4Err = port.udp4Service.Stop()
}
if port.udp6Service != nil {
udp6Err = port.udp6Service.Stop()
}
delete(s.ports, portNum)
if tcpErr != nil {
return fmt.Errorf("Failed to close listener on %v: %v", portNum, tcpErr)
}
if udpErr != nil {
return fmt.Errorf("Failed to close packetConn on %v: %v", portNum, udpErr)
if udp4Err != nil {
return fmt.Errorf("Failed to stop IPv4 UDP service on %v: %v", portNum, udp4Err)
}
if udp4Err != nil {
return fmt.Errorf("Failed to stop IPv6 UDP service on %v: %v", portNum, udp6Err)
}
logger.Infof("Stopped TCP and UDP on port %v", portNum)
return nil
Expand Down
42 changes: 26 additions & 16 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
}()

// Attempt to read an upstream packet.
clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf)
clientProxyBytes, clientAddr, proxyIP, err := onet.ReadFromWithDst(clientConn, cipherBuf)
if err != nil {
s.mu.RLock()
stopped = s.stopped
Expand Down Expand Up @@ -171,7 +171,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
cipherData := cipherBuf[:clientProxyBytes]
var payload []byte
var tgtUDPAddr *net.UDPAddr
targetConn := nm.Get(clientAddr.String())
targetConn := nm.Get(clientAddr, proxyIP)
if targetConn == nil {
var locErr error
clientLocation, locErr = s.m.GetLocation(clientAddr)
Expand All @@ -180,7 +180,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
}
debugUDPAddr(clientAddr, "Got location \"%s\"", clientLocation)

ip := clientAddr.(*net.UDPAddr).IP
ip := clientAddr.IP
var textData []byte
var cipher *ss.Cipher
unpackStart := time.Now()
Expand All @@ -200,7 +200,7 @@ func (s *udpService) Serve(clientConn net.PacketConn) error {
if err != nil {
return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
}
targetConn = nm.Add(clientAddr, clientConn, cipher, udpConn, clientLocation, keyID)
targetConn = nm.Add(clientAddr, proxyIP, clientConn, cipher, udpConn, clientLocation, keyID)
} else {
clientLocation = targetConn.clientLocation

Expand Down Expand Up @@ -335,29 +335,38 @@ func (c *natconn) ReadFrom(buf []byte) (int, net.Addr, error) {
return n, addr, err
}

type natkey struct {
clientAddr string // TODO: Use netip.AddrPort
proxyIP string // TODO: Use netip.Addr
}

func makeNATKey(clientAddr *net.UDPAddr, proxyIP net.IP) natkey {
return natkey{clientAddr.String(), proxyIP.String()}
}

// Packet NAT table
type natmap struct {
sync.RWMutex
keyConn map[string]*natconn
keyConn map[natkey]*natconn
timeout time.Duration
metrics metrics.ShadowsocksMetrics
running *sync.WaitGroup
}

func newNATmap(timeout time.Duration, sm metrics.ShadowsocksMetrics, running *sync.WaitGroup) *natmap {
m := &natmap{metrics: sm, running: running}
m.keyConn = make(map[string]*natconn)
m.keyConn = make(map[natkey]*natconn)
m.timeout = timeout
return m
}

func (m *natmap) Get(key string) *natconn {
func (m *natmap) Get(clientAddr *net.UDPAddr, proxyIP net.IP) *natconn {
m.RLock()
defer m.RUnlock()
return m.keyConn[key]
return m.keyConn[makeNATKey(clientAddr, proxyIP)]
}

func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn {
func (m *natmap) set(key natkey, pc net.PacketConn, cipher *ss.Cipher, keyID, clientLocation string) *natconn {
entry := &natconn{
PacketConn: pc,
cipher: cipher,
Expand All @@ -373,7 +382,7 @@ func (m *natmap) set(key string, pc net.PacketConn, cipher *ss.Cipher, keyID, cl
return entry
}

func (m *natmap) del(key string) net.PacketConn {
func (m *natmap) del(key natkey) net.PacketConn {
m.Lock()
defer m.Unlock()

Expand All @@ -385,15 +394,16 @@ func (m *natmap) del(key string) net.PacketConn {
return nil
}

func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn {
entry := m.set(clientAddr.String(), targetConn, cipher, keyID, clientLocation)
func (m *natmap) Add(clientAddr *net.UDPAddr, proxyIP net.IP, clientConn net.PacketConn, cipher *ss.Cipher, targetConn net.PacketConn, clientLocation, keyID string) *natconn {
key := makeNATKey(clientAddr, proxyIP)
entry := m.set(key, targetConn, cipher, keyID, clientLocation)

m.metrics.AddUDPNatEntry()
m.running.Add(1)
go func() {
timedCopy(clientAddr, clientConn, entry, keyID, m.metrics)
timedCopy(clientAddr, proxyIP, clientConn, entry, keyID, m.metrics)
m.metrics.RemoveUDPNatEntry()
if pc := m.del(clientAddr.String()); pc != nil {
if pc := m.del(key); pc != nil {
pc.Close()
}
m.running.Done()
Expand All @@ -420,7 +430,7 @@ func (m *natmap) Close() error {
var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345"))

// copy from target to client until read timeout
func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn,
func timedCopy(clientAddr *net.UDPAddr, proxyIP net.IP, clientConn net.PacketConn, targetConn *natconn,
keyID string, sm metrics.ShadowsocksMetrics) {
// pkt is used for in-place encryption of downstream UDP packets, with the layout
// [padding?][salt][address][body][tag][extra]
Expand Down Expand Up @@ -475,7 +485,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco
if err != nil {
return onet.NewConnectionError("ERR_PACK", "Failed to pack data to client", err)
}
proxyClientBytes, err = clientConn.WriteTo(buf, clientAddr)
proxyClientBytes, err = onet.WriteToWithSrc(clientConn, buf, proxyIP, clientAddr)
if err != nil {
return onet.NewConnectionError("ERR_WRITE", "Failed to write to client", err)
}
Expand Down
11 changes: 6 additions & 5 deletions service/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const timeout = 5 * time.Minute
var clientAddr = net.UDPAddr{IP: []byte{192, 0, 2, 1}, Port: 12345}
var targetAddr = net.UDPAddr{IP: []byte{192, 0, 2, 2}, Port: 54321}
var dnsAddr = net.UDPAddr{IP: []byte{192, 0, 2, 3}, Port: 53}
var proxyIP net.IP = []byte{192,0,2, 4}
var natCipher *ss.Cipher

func init() {
Expand Down Expand Up @@ -203,7 +204,7 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) {

func TestNATEmpty(t *testing.T) {
nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{})
if nat.Get("foo") != nil {
if nat.Get(&clientAddr, proxyIP) != nil {
t.Error("Expected nil value from empty NAT map")
}
}
Expand All @@ -212,8 +213,8 @@ func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) {
nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{})
clientConn := makePacketConn()
targetConn := makePacketConn()
nat.Add(&clientAddr, clientConn, natCipher, targetConn, "ZZ", "key id")
entry := nat.Get(clientAddr.String())
nat.Add(&clientAddr, proxyIP, clientConn, natCipher, targetConn, "ZZ", "key id")
entry := nat.Get(&clientAddr, proxyIP)
return clientConn, targetConn, entry
}

Expand Down Expand Up @@ -478,7 +479,7 @@ func TestUDPDoubleServe(t *testing.T) {

c := make(chan error)
for i := 0; i < 2; i++ {
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
clientConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
t.Fatalf("ListenUDP failed: %v", err)
}
Expand Down Expand Up @@ -513,7 +514,7 @@ func TestUDPEarlyStop(t *testing.T) {
if err := s.Stop(); err != nil {
t.Error(err)
}
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
clientConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
t.Fatalf("ListenUDP failed: %v", err)
}
Expand Down

0 comments on commit e73f380

Please sign in to comment.