diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go
index 4a38ad40..740fdec7 100644
--- a/cmd/outline-ss-server/main.go
+++ b/cmd/outline-ss-server/main.go
@@ -60,7 +60,7 @@ func init() {
 }
 
 type SSServer struct {
-	stopConfig func()
+	stopConfig func() error
 	lnManager  service.ListenerManager
 	natTimeout time.Duration
 	m          *outlineMetrics
@@ -83,12 +83,14 @@ func (s *SSServer) loadConfig(filename string) error {
 
 	// We hot swap the config by having the old and new listeners both live at
 	// the same time. This means we create listeners for the new config first,
 	// and then close the old ones after.
 	stopConfig, err := s.runConfig(*config)
 	if err != nil {
 		return err
 	}
-	go s.stopConfig()
-	s.stopConfig = stopConfig
+	if err := s.Stop(); err != nil {
+		return fmt.Errorf("unable to stop old config: %v", err)
+	}
+	s.stopConfig = stopConfig
 	return nil
 }
 
@@ -151,17 +153,17 @@ type listenerSet struct {
 	listenersMu        sync.Mutex
 }
 
-// ListenStream announces on a given TCP network address. Trying to listen on
-// the same address twice will result in an error.
+// ListenStream announces on a given network address. Trying to listen for stream connections
+// on the same address twice will result in an error.
 func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) {
 	ls.listenersMu.Lock()
 	defer ls.listenersMu.Unlock()
 
-	lnKey := "tcp/" + addr
+	lnKey := "stream/" + addr
 	if _, exists := ls.listenerCloseFuncs[lnKey]; exists {
-		return nil, fmt.Errorf("listener %s already exists", lnKey)
+		return nil, fmt.Errorf("stream listener for %s already exists", addr)
 	}
-	ln, err := ls.manager.ListenStream("tcp", addr)
+	ln, err := ls.manager.ListenStream(addr)
 	if err != nil {
 		return nil, err
 	}
@@ -169,17 +171,17 @@ func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error)
 	return ln, nil
 }
 
-// ListenPacket announces on a given UDP network address. 
Trying to listen on -// the same address twice will result in an error. +// ListenPacket announces on a given network address. Trying to listen for packet connections +// on the same address twice will result in an error. func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - lnKey := "udp/" + addr + lnKey := "packet/" + addr if _, exists := ls.listenerCloseFuncs[lnKey]; exists { - return nil, fmt.Errorf("listener %s already exists", lnKey) + return nil, fmt.Errorf("packet listener for %s already exists", addr) } - ln, err := ls.manager.ListenPacket("udp", addr) + ln, err := ls.manager.ListenPacket(addr) if err != nil { return nil, err } @@ -189,13 +191,14 @@ func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { // Close closes all the listeners in the set, after which the set can't be used again. func (ls *listenerSet) Close() error { + ls.listenersMu.Lock() + defer ls.listenersMu.Unlock() + for addr, listenerCloseFunc := range ls.listenerCloseFuncs { if err := listenerCloseFunc(); err != nil { return fmt.Errorf("listener on address %s failed to stop: %w", addr, err) } } - ls.listenersMu.Lock() - defer ls.listenersMu.Unlock() ls.listenerCloseFuncs = nil return nil } @@ -205,8 +208,9 @@ func (ls *listenerSet) Len() int { return len(ls.listenerCloseFuncs) } -func (s *SSServer) runConfig(config Config) (func(), error) { +func (s *SSServer) runConfig(config Config) (func() error, error) { startErrCh := make(chan error) + stopErrCh := make(chan error) stopCh := make(chan struct{}) go func() { @@ -214,7 +218,9 @@ func (s *SSServer) runConfig(config Config) (func(), error) { manager: s.lnManager, listenerCloseFuncs: make(map[string]func() error), } - defer lnSet.Close() // This closes all the listeners in the set. 
+ defer func() { + stopErrCh <- lnSet.Close() + }() startErrCh <- func() error { totalCipherCount := len(config.Keys) @@ -305,24 +311,33 @@ func (s *SSServer) runConfig(config Config) (func(), error) { if err != nil { return nil, err } - return func() { + return func() error { logger.Infof("Stopping running config.") // TODO(sbruens): Actually wait for all handlers to be stopped, e.g. by // using a https://pkg.go.dev/sync#WaitGroup. stopCh <- struct{}{} + stopErr := <-stopErrCh + return stopErr }, nil } // Stop stops serving the current config. -func (s *SSServer) Stop() { - go s.stopConfig() +func (s *SSServer) Stop() error { + stopFunc := s.stopConfig + if stopFunc == nil { + return nil + } + if err := stopFunc(); err != nil { + logger.Errorf("Error stopping config: %v", err) + return err + } logger.Info("Stopped all listeners for running config") + return nil } // RunSSServer starts a shadowsocks server running, and returns the server or an error. func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, replayHistory int) (*SSServer, error) { server := &SSServer{ - stopConfig: func() {}, lnManager: service.NewListenerManager(), natTimeout: natTimeout, m: sm, diff --git a/cmd/outline-ss-server/server_test.go b/cmd/outline-ss-server/server_test.go index 2ba0772e..0b7777b2 100644 --- a/cmd/outline-ss-server/server_test.go +++ b/cmd/outline-ss-server/server_test.go @@ -27,5 +27,7 @@ func TestRunSSServer(t *testing.T) { if err != nil { t.Fatalf("RunSSServer() error = %v", err) } - server.Stop() + if err := server.Stop(); err != nil { + t.Errorf("Error while stopping server: %v", err) + } } diff --git a/service/listeners.go b/service/listeners.go index 8f0882ba..788d651f 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -64,25 +64,31 @@ func (t *TCPListener) Addr() net.Addr { return t.ln.Addr() } +type OnCloseFunc func() error + type acceptResponse struct { conn transport.StreamConn err error } -type OnCloseFunc func() error - type 
virtualStreamListener struct { - listener StreamListener - acceptCh chan acceptResponse + mu sync.Mutex // Mutex to protect access to the channels + addr net.Addr + acceptCh <-chan acceptResponse closeCh chan struct{} + closed bool onCloseFunc OnCloseFunc } var _ StreamListener = (*virtualStreamListener)(nil) func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { + sl.mu.Lock() + acceptCh := sl.acceptCh + sl.mu.Unlock() + select { - case acceptResponse, ok := <-sl.acceptCh: + case acceptResponse, ok := <-acceptCh: if !ok { return nil, net.ErrClosed } @@ -93,201 +99,270 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { } func (sl *virtualStreamListener) Close() error { + sl.mu.Lock() + if sl.closed { + sl.mu.Unlock() + return nil + } + sl.closed = true + sl.acceptCh = nil close(sl.closeCh) - return sl.onCloseFunc() + sl.mu.Unlock() + + if sl.onCloseFunc != nil { + return sl.onCloseFunc() + } + return nil } func (sl *virtualStreamListener) Addr() net.Addr { - return sl.listener.Addr() + return sl.addr +} + +type packetResponse struct { + n int + addr net.Addr + err error + data []byte } type virtualPacketConn struct { net.PacketConn + mu sync.Mutex // Mutex to protect access to the channels + readCh <-chan packetResponse + closeCh chan struct{} + closed bool onCloseFunc OnCloseFunc } -func (spc *virtualPacketConn) Close() error { - return spc.onCloseFunc() -} - -type listenAddr struct { - ln Listener - acceptCh chan acceptResponse - onCloseFunc OnCloseFunc -} +func (pc *virtualPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + pc.mu.Lock() + readCh := pc.readCh + pc.mu.Unlock() -type canCreateStreamListener interface { - NewStreamListener(onCloseFunc OnCloseFunc) StreamListener + select { + case packetResponse, ok := <-readCh: + if !ok { + return 0, nil, net.ErrClosed + } + copy(p, packetResponse.data) + return packetResponse.n, packetResponse.addr, packetResponse.err + case <-pc.closeCh: + 
return 0, nil, net.ErrClosed + } } -var _ canCreateStreamListener = (*listenAddr)(nil) +func (pc *virtualPacketConn) Close() error { + pc.mu.Lock() + if pc.closed { + pc.mu.Unlock() + return nil + } + pc.closed = true + pc.readCh = nil + close(pc.closeCh) + pc.mu.Unlock() -// NewStreamListener creates a new [StreamListener]. -func (la *listenAddr) NewStreamListener(onCloseFunc OnCloseFunc) StreamListener { - if ln, ok := la.ln.(StreamListener); ok { - return &virtualStreamListener{ - listener: ln, - acceptCh: la.acceptCh, - closeCh: make(chan struct{}), - onCloseFunc: onCloseFunc, - } + if pc.onCloseFunc != nil { + return pc.onCloseFunc() } return nil } -type canCreatePacketListener interface { - NewPacketListener(onCloseFunc OnCloseFunc) net.PacketConn +// MultiListener manages shared listeners. +type MultiListener[T Listener] interface { + // Acquire creates a new listener from the shared listener. Listeners can overlap + // one another (e.g. during config changes the new config is started before the + // old config is destroyed), which is done by creating virtual listeners that wrap + // the shared listener. These virtual listeners do not actually close the + // underlying socket until all uses of the shared listener have been closed. + Acquire() (T, error) } -var _ canCreatePacketListener = (*listenAddr)(nil) +type multiStreamListener struct { + mu sync.Mutex + addr string + ln RefCount[StreamListener] + acceptCh chan acceptResponse + onCloseFunc OnCloseFunc +} -// NewPacketListener creates a new [net.PacketConn]. -func (cl *listenAddr) NewPacketListener(onCloseFunc OnCloseFunc) net.PacketConn { - if ln, ok := cl.ln.(net.PacketConn); ok { - return &virtualPacketConn{ - PacketConn: ln, - onCloseFunc: onCloseFunc, - } +// NewMultiStreamListener creates a new stream-based [MultiListener]. 
+func NewMultiStreamListener(addr string, onCloseFunc OnCloseFunc) MultiListener[StreamListener] { + return &multiStreamListener{ + addr: addr, + onCloseFunc: onCloseFunc, } - return nil } -func (la *listenAddr) Close() error { - if err := la.ln.Close(); err != nil { - return err +func (m *multiStreamListener) Acquire() (StreamListener, error) { + refCount, err := func() (RefCount[StreamListener], error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.ln == nil { + tcpAddr, err := net.ResolveTCPAddr("tcp", m.addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + sl := &TCPListener{ln} + m.ln = NewRefCount[StreamListener](sl, m.onCloseFunc) + m.acceptCh = make(chan acceptResponse) + go func() { + for { + conn, err := sl.AcceptStream() + if errors.Is(err, net.ErrClosed) { + close(m.acceptCh) + return + } + m.acceptCh <- acceptResponse{conn, err} + } + }() + } + return m.ln, nil + }() + if err != nil { + return nil, err } - return la.onCloseFunc() -} -// ListenerManager holds and manages the state of shared listeners. -type ListenerManager interface { - // ListenStream creates a new stream listener for a given network and address. - // - // Listeners can overlap one another, because during config changes the new - // config is started before the old config is destroyed. This is done by using - // reusable listener wrappers, which do not actually close the underlying socket - // until all uses of the shared listener have been closed. - ListenStream(network string, addr string) (StreamListener, error) - - // ListenPacket creates a new packet listener for a given network and address. - // - // See notes on [ListenStream]. 
- ListenPacket(network string, addr string) (net.PacketConn, error) + sl := refCount.Acquire() + return &virtualStreamListener{ + addr: sl.Addr(), + acceptCh: m.acceptCh, + closeCh: make(chan struct{}), + onCloseFunc: refCount.Close, + }, nil } -type listenerManager struct { - listeners map[string]RefCount[Listener] - listenersMu sync.Mutex +type multiPacketListener struct { + mu sync.Mutex + addr string + pc RefCount[net.PacketConn] + readCh chan packetResponse + onCloseFunc OnCloseFunc } -// NewListenerManager creates a new [ListenerManger]. -func NewListenerManager() ListenerManager { - return &listenerManager{ - listeners: make(map[string]RefCount[Listener]), +// NewMultiPacketListener creates a new packet-based [MultiListener]. +func NewMultiPacketListener(addr string, onCloseFunc OnCloseFunc) MultiListener[net.PacketConn] { + return &multiPacketListener{ + addr: addr, + onCloseFunc: onCloseFunc, } } -func (m *listenerManager) newStreamListener(network string, addr string) (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP(network, tcpAddr) - if err != nil { - return nil, err - } - streamLn := &TCPListener{ln} - lnAddr := &listenAddr{ - ln: streamLn, - acceptCh: make(chan acceptResponse), - onCloseFunc: func() error { - m.delete(listenerKey(network, addr)) - return nil - }, - } - go func() { - for { - conn, err := streamLn.AcceptStream() - if errors.Is(err, net.ErrClosed) { - close(lnAddr.acceptCh) - return +func (m *multiPacketListener) Acquire() (net.PacketConn, error) { + refCount, err := func() (RefCount[net.PacketConn], error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.pc == nil { + pc, err := net.ListenPacket("udp", m.addr) + if err != nil { + return nil, err } - lnAddr.acceptCh <- acceptResponse{conn, err} + m.pc = NewRefCount(pc, m.onCloseFunc) + m.readCh = make(chan packetResponse) + go func() { + for { + buffer := make([]byte, serverUDPBufferSize) + n, addr, err 
:= pc.ReadFrom(buffer) + if err != nil { + close(m.readCh) + return + } + m.readCh <- packetResponse{n: n, addr: addr, err: err, data: buffer[:n]} + } + }() } + return m.pc, nil }() - return lnAddr, nil -} - -func (m *listenerManager) newPacketListener(network string, addr string) (Listener, error) { - pc, err := net.ListenPacket(network, addr) if err != nil { return nil, err } - return &listenAddr{ - ln: pc, - onCloseFunc: func() error { - m.delete(listenerKey(network, addr)) - return nil - }, + + pc := refCount.Acquire() + return &virtualPacketConn{ + PacketConn: pc, + readCh: m.readCh, + closeCh: make(chan struct{}), + onCloseFunc: refCount.Close, }, nil } -func (m *listenerManager) getListener(network string, addr string) (RefCount[Listener], error) { - m.listenersMu.Lock() - defer m.listenersMu.Unlock() +// ListenerManager holds the state of shared listeners. +type ListenerManager interface { + // ListenStream creates a new stream listener for a given address. + ListenStream(addr string) (StreamListener, error) - lnKey := listenerKey(network, addr) - if lnRefCount, exists := m.listeners[lnKey]; exists { - return lnRefCount.Acquire(), nil - } + // ListenPacket creates a new packet listener for a given address. 
+ ListenPacket(addr string) (net.PacketConn, error) +} - var ( - ln Listener - err error - ) - if network == "tcp" { - ln, err = m.newStreamListener(network, addr) - } else { - ln, err = m.newPacketListener(network, addr) - } - if err != nil { - return nil, err - } - lnRefCount := NewRefCount(ln) - m.listeners[lnKey] = lnRefCount - return lnRefCount, nil +type listenerManager struct { + streamListeners map[string]MultiListener[StreamListener] + packetListeners map[string]MultiListener[net.PacketConn] + mu sync.Mutex } -func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { - lnRefCount, err := m.getListener(network, addr) - if err != nil { - return nil, err - } - if lnAddr, ok := lnRefCount.Get().(canCreateStreamListener); ok { - return lnAddr.NewStreamListener(lnRefCount.Close), nil +// NewListenerManager creates a new [ListenerManger]. +func NewListenerManager() ListenerManager { + return &listenerManager{ + streamListeners: make(map[string]MultiListener[StreamListener]), + packetListeners: make(map[string]MultiListener[net.PacketConn]), } - return nil, fmt.Errorf("unable to create stream listener for %s/%s", network, addr) } -func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) { - lnRefCount, err := m.getListener(network, addr) +func (m *listenerManager) ListenStream(addr string) (StreamListener, error) { + m.mu.Lock() + defer m.mu.Unlock() + + streamLn, exists := m.streamListeners[addr] + if !exists { + streamLn = NewMultiStreamListener( + addr, + func() error { + m.mu.Lock() + delete(m.streamListeners, addr) + m.mu.Unlock() + return nil + }, + ) + m.streamListeners[addr] = streamLn + } + ln, err := streamLn.Acquire() if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create stream listener for %s: %v", addr, err) } - if lnAddr, ok := lnRefCount.Get().(canCreatePacketListener); ok { - return lnAddr.NewPacketListener(lnRefCount.Close), nil + return ln, nil +} + 
+func (m *listenerManager) ListenPacket(addr string) (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + + packetLn, exists := m.packetListeners[addr] + if !exists { + packetLn = NewMultiPacketListener( + addr, + func() error { + m.mu.Lock() + delete(m.packetListeners, addr) + m.mu.Unlock() + return nil + }, + ) + m.packetListeners[addr] = packetLn } - return nil, fmt.Errorf("unable to create packet listener for %s/%s", network, addr) -} - -func (m *listenerManager) delete(key string) { - m.listenersMu.Lock() - delete(m.listeners, key) - m.listenersMu.Unlock() -} -func listenerKey(network string, addr string) string { - return network + "/" + addr + ln, err := packetLn.Acquire() + if err != nil { + return nil, fmt.Errorf("unable to create packet listener for %s: %v", addr, err) + } + return ln, nil } // RefCount is an atomic reference counter that can be used to track a shared @@ -296,38 +371,44 @@ type RefCount[T io.Closer] interface { io.Closer // Acquire increases the ref count and returns the wrapped object. - Acquire() RefCount[T] - - Get() T + Acquire() T } type refCount[T io.Closer] struct { - count *atomic.Int32 - value T + mu sync.Mutex + count *atomic.Int32 + value T + onCloseFunc OnCloseFunc } -func NewRefCount[T io.Closer](value T) RefCount[T] { - res := &refCount[T]{ - count: &atomic.Int32{}, - value: value, +func NewRefCount[T io.Closer](value T, onCloseFunc OnCloseFunc) RefCount[T] { + r := &refCount[T]{ + count: &atomic.Int32{}, + value: value, + onCloseFunc: onCloseFunc, } - res.count.Store(1) - return res + return r +} + +func (r refCount[T]) Acquire() T { + r.count.Add(1) + return r.value } func (r refCount[T]) Close() error { + // Lock to prevent someone from acquiring while we close the value. 
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
 	if count := r.count.Add(-1); count == 0 {
-		return r.value.Close()
+		err := r.value.Close()
+		if err != nil {
+			return err
+		}
+		if r.onCloseFunc != nil {
+			return r.onCloseFunc()
+		}
 	}
 	return nil
 }
-
-func (r refCount[T]) Acquire() RefCount[T] {
-	r.count.Add(1)
-	return r
-}
-
-func (r refCount[T]) Get() T {
-	return r.value
-}
diff --git a/service/listeners_test.go b/service/listeners_test.go
index 384126ad..d627ec1a 100644
--- a/service/listeners_test.go
+++ b/service/listeners_test.go
@@ -24,7 +24,7 @@ import (
 func TestListenerManagerStreamListenerEarlyClose(t *testing.T) {
 	m := NewListenerManager()
 
-	ln, err := m.ListenStream("tcp", "127.0.0.1:0")
+	ln, err := m.ListenStream("127.0.0.1:0")
 	require.NoError(t, err)
 
 	ln.Close()
@@ -47,33 +47,32 @@ func writeTestPayload(ln StreamListener) error {
 func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) {
 	m := NewListenerManager()
 
-	ln, err := m.ListenStream("tcp", "127.0.0.1:0")
+	ln, err := m.ListenStream("127.0.0.1:0")
 	require.NoError(t, err)
-	ln2, err := m.ListenStream("tcp", "127.0.0.1:0")
+	ln2, err := m.ListenStream("127.0.0.1:0")
 	require.NoError(t, err)
-
 	// Close only the first listener.
 	ln.Close()
+
 	done := make(chan struct{})
 	go func() {
 		ln2.AcceptStream()
 		done <- struct{}{}
 	}()
-
 	err = writeTestPayload(ln2)
-	require.NoError(t, err)
+	require.NoError(t, err)
 	<-done
 }
 
 func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) {
 	m := NewListenerManager()
 	// Create a listener and immediately close it.
-	ln, err := m.ListenStream("tcp", "127.0.0.1:0")
+	ln, err := m.ListenStream("127.0.0.1:0")
 	require.NoError(t, err)
 	ln.Close()
 	// Now create another listener on the same address. 
- ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + ln2, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) done := make(chan struct{}) @@ -82,8 +81,64 @@ func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) { done <- struct{}{} }() err = writeTestPayload(ln2) + + require.NoError(t, err) + <-done +} + +func TestListenerManagerPacketListenerEarlyClose(t *testing.T) { + m := NewListenerManager() + pc, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + + pc.Close() + _, _, readErr := pc.ReadFrom(nil) + _, writeErr := pc.WriteTo(nil, &net.UDPAddr{}) + + require.ErrorIs(t, readErr, net.ErrClosed) + require.ErrorIs(t, writeErr, net.ErrClosed) +} + +func TestListenerManagerPacketListenerNotClosedIfStillInUse(t *testing.T) { + m := NewListenerManager() + pc, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + pc2, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + // Close only the first listener. + pc.Close() + + done := make(chan struct{}) + go func() { + _, _, readErr := pc2.ReadFrom(nil) + require.NoError(t, readErr) + done <- struct{}{} + }() + _, err = pc.WriteTo(nil, pc2.LocalAddr()) + require.NoError(t, err) + <-done +} +func TestListenerManagerPacketListenerCreatesListenerOnDemand(t *testing.T) { + m := NewListenerManager() + // Create a listener and immediately close it. + pc, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + pc.Close() + // Now create another listener on the same address. 
+ pc2, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + _, _, readErr := pc2.ReadFrom(nil) + require.NoError(t, readErr) + done <- struct{}{} + }() + _, err = pc2.WriteTo(nil, pc2.LocalAddr()) + + require.NoError(t, err) <-done } @@ -97,17 +152,27 @@ func (t *testRefCount) Close() error { } func TestRefCount(t *testing.T) { - var done bool - rc := NewRefCount[*testRefCount](&testRefCount{ - onCloseFunc: func() { - done = true + var objectCloseDone bool + var onCloseFuncDone bool + rc := NewRefCount[*testRefCount]( + &testRefCount{ + onCloseFunc: func() { + objectCloseDone = true + }, }, - }) + func() error { + onCloseFuncDone = true + return nil + }, + ) + rc.Acquire() rc.Acquire() require.NoError(t, rc.Close()) - require.False(t, done) + require.False(t, objectCloseDone) + require.False(t, onCloseFuncDone) require.NoError(t, rc.Close()) - require.True(t, done) + require.True(t, objectCloseDone) + require.True(t, onCloseFuncDone) }