From 022d8e8c97bb7b933739f11b510d13dbfde06a57 Mon Sep 17 00:00:00 2001 From: Saber Haj Rabiee Date: Tue, 27 Aug 2024 09:16:06 -0700 Subject: [PATCH] fix: make creating tcp and udp handles flexible --- cmd/outline-ss-server/main.go | 10 +++++-- service/tcp.go | 10 +++---- service/udp.go | 49 ++++++++++++++++++++++------------- 3 files changed, 44 insertions(+), 25 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 42a203ce..39e6318b 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -34,6 +34,7 @@ import ( "gopkg.in/yaml.v2" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" + onet "github.com/Jigsaw-Code/outline-ss-server/net" "github.com/Jigsaw-Code/outline-ss-server/service" ) @@ -87,8 +88,13 @@ func (s *SSServer) startPort(portNum int, fwmark uint) error { port := &ssPort{tcpListener: listener, packetConn: packetConn, cipherList: service.NewCipherList()} authFunc := service.NewShadowsocksStreamAuthenticator(port.cipherList, &s.replayCache, s.m) // TODO: Register initial data metrics at zero. - tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout, fwmark) - packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m, fwmark) + + tcpDialer := service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, fwmark) + tcpHandler := service.NewTCPHandler(authFunc, s.m, tcpReadTimeout, tcpDialer) + udpPacketDialer := func() (net.PacketConn, *onet.ConnectionError) { + return service.MakeTargetPacketConnection(fwmark) + } + packetHandler := service.NewPacketHandler(s.natTimeout, port.cipherList, s.m, udpPacketDialer) s.ports[portNum] = port go service.StreamServe(service.WrapStreamListener(listener.AcceptTCP), tcpHandler.Handle) go packetHandler.Handle(port.packetConn) diff --git a/service/tcp.go b/service/tcp.go index b0640aa3..10562e3b 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -172,18 +172,18 @@ type tcpHandler struct { } // NewTCPService creates a TCPService -func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration, fwmark uint) TCPHandler { - defaultDialer := makeValidatingTCPStreamDialer(onet.RequirePublicIP, fwmark) - +func NewTCPHandler(authenticate StreamAuthenticateFunc, m TCPMetrics, timeout time.Duration, dialer transport.StreamDialer) TCPHandler { return &tcpHandler{ m: m, readTimeout: timeout, authenticate: authenticate, - dialer: defaultDialer, + dialer: dialer, } } -func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer { +// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. +// Value of 0 disables fwmark (SO_MARK) +func MakeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator, fwmark uint) transport.StreamDialer { return &transport.TCPDialer{Dialer: net.Dialer{Control: func(network, address string, c syscall.RawConn) error { if fwmark > 0 { err := c.Control(func(fd uintptr) { diff --git a/service/udp.go b/service/udp.go index 8550ba6e..2eaebd3e 100644 --- a/service/udp.go +++ b/service/udp.go @@ -86,21 +86,24 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis return nil, "", nil, errors.New("could not find valid UDP cipher") } +// Type alias for creating UDP sockets +type UDPConnFactory = func() (net.PacketConn, *onet.ConnectionError) + type packetHandler struct { natTimeout time.Duration ciphers CipherList m UDPMetrics targetIPValidator onet.TargetIPValidator - fwmark uint + udpConnFactory UDPConnFactory } // NewPacketHandler creates a UDPService -func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetrics, fwmark uint) PacketHandler { +func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetrics, udpConnFactory UDPConnFactory) PacketHandler { return &packetHandler{ natTimeout: natTimeout, ciphers: cipherList, m: m, targetIPValidator: onet.RequirePublicIP, - fwmark: fwmark, + udpConnFactory: udpConnFactory, } } @@ -116,6 +119,29 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali h.targetIPValidator = targetIPValidator } +// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management. +// Value of 0 disables fwmark (SO_MARK) +func MakeTargetPacketConnection(fwmark uint) (net.PacketConn, *onet.ConnectionError) { + udpConn, err := net.ListenPacket("udp", "") + if err != nil { + return nil, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) + } + + if fwmark > 0 { + file, err := udpConn.(*net.UDPConn).File() + if err != nil { + return nil, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to get UDP socket file", err) + } + defer file.Close() + + err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_SOCKET, syscall.SO_MARK, int(fwmark)) + if err != nil { + slog.Error("Set fwmark failed.", "err", os.NewSyscallError("failed to set mark for UDP socket", err)) + } + } + return udpConn, nil +} + // Listen on addr for encrypted packets and basically do UDP NAT. // We take the ciphers as a pointer because it gets replaced on config updates. func (h *packetHandler) Handle(clientConn net.PacketConn) { @@ -180,22 +206,9 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { return onetErr } - udpConn, err := net.ListenPacket("udp", "") + udpConn, err := h.udpConnFactory() if err != nil { - return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) - } - - if h.fwmark > 0 { - file, err := udpConn.(*net.UDPConn).File() - if err != nil { - return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to get UDP socket file", err) - } - defer file.Close() - - err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_SOCKET, syscall.SO_MARK, int(h.fwmark)) - if err != nil { - slog.Error("Set fwmark failed.", "err", os.NewSyscallError("failed to set mark for UDP socket", err)) - } + return err } targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID)