From fe8bbdda37bbb3927761690bf5ded40ed7b11ff7 Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 5 Aug 2024 15:22:13 -0400 Subject: [PATCH 01/10] Address review comments. --- cmd/outline-ss-server/main.go | 20 +++++++-------- service/listeners.go | 47 +++++++++++++++++++---------------- service/listeners_test.go | 10 ++++---- 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index ec6b66b0..f180044c 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -101,17 +101,17 @@ type listenerSet struct { listenersMu sync.Mutex } -// ListenStream announces on a given TCP network address. Trying to listen on -// the same address twice will result in an error. +// ListenStream announces on a given network address. Trying to listen for stream connections +// on the same address twice will result in an error. func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) { ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - lnKey := "tcp/" + addr + lnKey := "stream-" + addr if _, exists := ls.listenerCloseFuncs[lnKey]; exists { - return nil, fmt.Errorf("listener %s already exists", lnKey) + return nil, fmt.Errorf("stream listener for %s already exists", addr) } - ln, err := ls.manager.ListenStream("tcp", addr) + ln, err := ls.manager.ListenStream(addr) if err != nil { return nil, err } @@ -119,17 +119,17 @@ func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) return ln, nil } -// ListenPacket announces on a given UDP network address. Trying to listen on -// the same address twice will result in an error. +// ListenPacket announces on a given network address. Trying to listen for packet connections +// on the same address twice will result in an error. func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - lnKey := "udp/" + addr + lnKey := "packet-" + addr if _, exists := ls.listenerCloseFuncs[lnKey]; exists { - return nil, fmt.Errorf("listener %s already exists", lnKey) + return nil, fmt.Errorf("packet listener for %s already exists", addr) } - ln, err := ls.manager.ListenPacket("udp", addr) + ln, err := ls.manager.ListenPacket(addr) if err != nil { return nil, err } diff --git a/service/listeners.go b/service/listeners.go index 8f0882ba..d8d58192 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -72,8 +72,8 @@ type acceptResponse struct { type OnCloseFunc func() error type virtualStreamListener struct { - listener StreamListener - acceptCh chan acceptResponse + addr net.Addr + acceptCh <-chan acceptResponse closeCh chan struct{} onCloseFunc OnCloseFunc } @@ -93,12 +93,13 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { } func (sl *virtualStreamListener) Close() error { + sl.acceptCh = nil close(sl.closeCh) return sl.onCloseFunc() } func (sl *virtualStreamListener) Addr() net.Addr { - return sl.listener.Addr() + return sl.addr } type virtualPacketConn struct { @@ -126,7 +127,7 @@ var _ canCreateStreamListener = (*listenAddr)(nil) func (la *listenAddr) NewStreamListener(onCloseFunc OnCloseFunc) StreamListener { if ln, ok := la.ln.(StreamListener); ok { return &virtualStreamListener{ - listener: ln, + addr: ln.Addr(), acceptCh: la.acceptCh, closeCh: make(chan struct{}), onCloseFunc: onCloseFunc, @@ -161,18 +162,18 @@ func (la *listenAddr) Close() error { // ListenerManager holds and manages the state of shared listeners. type ListenerManager interface { - // ListenStream creates a new stream listener for a given network and address. + // ListenStream creates a new stream listener for a given address. // // Listeners can overlap one another, because during config changes the new // config is started before the old config is destroyed. This is done by using // reusable listener wrappers, which do not actually close the underlying socket // until all uses of the shared listener have been closed. - ListenStream(network string, addr string) (StreamListener, error) + ListenStream(addr string) (StreamListener, error) - // ListenPacket creates a new packet listener for a given network and address. + // ListenPacket creates a new packet listener for a given address. // // See notes on [ListenStream]. - ListenPacket(network string, addr string) (net.PacketConn, error) + ListenPacket(addr string) (net.PacketConn, error) } type listenerManager struct { @@ -187,12 +188,12 @@ func NewListenerManager() ListenerManager { } } -func (m *listenerManager) newStreamListener(network string, addr string) (Listener, error) { +func (m *listenerManager) newStreamListener(addr string) (Listener, error) { tcpAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return nil, err } - ln, err := net.ListenTCP(network, tcpAddr) + ln, err := net.ListenTCP("tcp", tcpAddr) if err != nil { return nil, err } @@ -201,7 +202,7 @@ func (m *listenerManager) newStreamListener(network string, addr string) (Listen ln: streamLn, acceptCh: make(chan acceptResponse), onCloseFunc: func() error { - m.delete(listenerKey(network, addr)) + m.delete(listenerKey("tcp", addr)) return nil }, } @@ -218,15 +219,15 @@ func (m *listenerManager) newStreamListener(network string, addr string) (Listen return lnAddr, nil } -func (m *listenerManager) newPacketListener(network string, addr string) (Listener, error) { - pc, err := net.ListenPacket(network, addr) +func (m *listenerManager) newPacketListener(addr string) (Listener, error) { + pc, err := net.ListenPacket("udp", addr) if err != nil { return nil, err } return &listenAddr{ ln: pc, onCloseFunc: func() error { - m.delete(listenerKey(network, addr)) + m.delete(listenerKey("udp", addr)) return nil }, }, nil @@ -246,9 +247,11 @@ func (m *listenerManager) getListener(network string, addr string) (RefCount[Lis err error ) if network == "tcp" { - ln, err = m.newStreamListener(network, addr) + ln, err = m.newStreamListener(addr) + } else if network == "udp" { + ln, err = m.newPacketListener(addr) } else { - ln, err = m.newPacketListener(network, addr) + return nil, fmt.Errorf("unable to get listener for unsupported network %s", network) } if err != nil { return nil, err @@ -258,26 +261,26 @@ func (m *listenerManager) getListener(network string, addr string) (RefCount[Lis return lnRefCount, nil } -func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { - lnRefCount, err := m.getListener(network, addr) +func (m *listenerManager) ListenStream(addr string) (StreamListener, error) { + lnRefCount, err := m.getListener("tcp", addr) if err != nil { return nil, err } if lnAddr, ok := lnRefCount.Get().(canCreateStreamListener); ok { return lnAddr.NewStreamListener(lnRefCount.Close), nil } - return nil, fmt.Errorf("unable to create stream listener for %s/%s", network, addr) + return nil, fmt.Errorf("unable to create stream listener for %s", addr) } -func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) { - lnRefCount, err := m.getListener(network, addr) +func (m *listenerManager) ListenPacket(addr string) (net.PacketConn, error) { + lnRefCount, err := m.getListener("udp", addr) if err != nil { return nil, err } if lnAddr, ok := lnRefCount.Get().(canCreatePacketListener); ok { return lnAddr.NewPacketListener(lnRefCount.Close), nil } - return nil, fmt.Errorf("unable to create packet listener for %s/%s", network, addr) + return nil, fmt.Errorf("unable to create packet listener for %s", addr) } func (m *listenerManager) delete(key string) { diff --git a/service/listeners_test.go b/service/listeners_test.go index 384126ad..da5aaa1e 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -24,7 +24,7 @@ import ( func TestListenerManagerStreamListenerEarlyClose(t *testing.T) { m := NewListenerManager() - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) ln.Close() @@ -47,9 +47,9 @@ func writeTestPayload(ln StreamListener) error { func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { m := NewListenerManager() - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) - ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + ln2, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) // Close only the first listener. @@ -69,11 +69,11 @@ func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) { m := NewListenerManager() // Create a listener and immediately close it. - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) ln.Close() // Now create another listener on the same address. - ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + ln2, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) done := make(chan struct{}) From 36a0a1d9f43c25e023b7c9b70fd54dcdae2dcd5e Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 5 Aug 2024 18:32:20 -0400 Subject: [PATCH 02/10] Use a mutex to ensure another user doesn't acquire a new closer while we're closing it. --- service/listeners.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index d8d58192..648f315e 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -305,24 +305,28 @@ type RefCount[T io.Closer] interface { } type refCount[T io.Closer] struct { + mu sync.Mutex count *atomic.Int32 value T } func NewRefCount[T io.Closer](value T) RefCount[T] { - res := &refCount[T]{ + r := &refCount[T]{ count: &atomic.Int32{}, value: value, } - res.count.Store(1) - return res + r.count.Store(1) + return r } func (r refCount[T]) Close() error { + // Lock to prevent someone from acquiring while we close the value. + r.mu.Lock() + defer r.mu.Unlock() + if count := r.count.Add(-1); count == 0 { return r.value.Close() } - return nil } From aeb2652fb3da4ea7339c88fcba7e32c5b91a5454 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 6 Aug 2024 10:13:00 -0400 Subject: [PATCH 03/10] Move mutex up. --- cmd/outline-ss-server/main.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index f180044c..55a11acc 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -139,13 +139,14 @@ func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { // Close closes all the listeners in the set, after which the set can't be used again. func (ls *listenerSet) Close() error { + ls.listenersMu.Lock() + defer ls.listenersMu.Unlock() + for addr, listenerCloseFunc := range ls.listenerCloseFuncs { if err := listenerCloseFunc(); err != nil { return fmt.Errorf("listener on address %s failed to stop: %w", addr, err) } } - ls.listenersMu.Lock() - defer ls.listenersMu.Unlock() ls.listenerCloseFuncs = nil return nil } From 8873b107083fedfe2b6627f326405c039177e4a3 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 6 Aug 2024 17:33:18 -0400 Subject: [PATCH 04/10] Manage the ref counting next to the listener creation. --- service/listeners.go | 310 +++++++++++++++++++------------------- service/listeners_test.go | 24 ++- 2 files changed, 170 insertions(+), 164 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 648f315e..91aa877f 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -95,7 +95,10 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { func (sl *virtualStreamListener) Close() error { sl.acceptCh = nil close(sl.closeCh) - return sl.onCloseFunc() + if sl.onCloseFunc != nil { + return sl.onCloseFunc() + } + return nil } func (sl *virtualStreamListener) Addr() net.Addr { @@ -108,189 +111,180 @@ type virtualPacketConn struct { } func (spc *virtualPacketConn) Close() error { - return spc.onCloseFunc() + if spc.onCloseFunc != nil { + return spc.onCloseFunc() + } + return nil } -type listenAddr struct { - ln Listener +// MultiListener manages shared listeners. +type MultiListener[T Listener] interface { + // Acquire creates a new listener from the shared listener. Listeners can overlap + // one another (e.g. during config changes the new config is started before the + // old config is destroyed), which is done by creating virtual listeners that wrap + // the shared listener. These virtual listeners do not actually close the + // underlying socket until all uses of the shared listener have been closed. + Acquire() (T, error) +} + +type multiStreamListener struct { + mu sync.Mutex + addr string + ln RefCount[StreamListener] acceptCh chan acceptResponse onCloseFunc OnCloseFunc } -type canCreateStreamListener interface { - NewStreamListener(onCloseFunc OnCloseFunc) StreamListener +// NewMultiStreamListener creates a new stream-based [MultiListener]. +func NewMultiStreamListener(addr string, onCloseFunc OnCloseFunc) MultiListener[StreamListener] { + return &multiStreamListener{ + addr: addr, + acceptCh: make(chan acceptResponse), + onCloseFunc: onCloseFunc, + } } -var _ canCreateStreamListener = (*listenAddr)(nil) +func (m *multiStreamListener) Acquire() (StreamListener, error) { + m.mu.Lock() + defer m.mu.Unlock() -// NewStreamListener creates a new [StreamListener]. -func (la *listenAddr) NewStreamListener(onCloseFunc OnCloseFunc) StreamListener { - if ln, ok := la.ln.(StreamListener); ok { - return &virtualStreamListener{ - addr: ln.Addr(), - acceptCh: la.acceptCh, - closeCh: make(chan struct{}), - onCloseFunc: onCloseFunc, + var sl StreamListener + if m.ln == nil { + tcpAddr, err := net.ResolveTCPAddr("tcp", m.addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err } + sl = &TCPListener{ln} + m.ln = NewRefCount(sl, m.onCloseFunc) + go func() { + for { + conn, err := sl.AcceptStream() + if errors.Is(err, net.ErrClosed) { + close(m.acceptCh) + return + } + m.acceptCh <- acceptResponse{conn, err} + } + }() } - return nil -} -type canCreatePacketListener interface { - NewPacketListener(onCloseFunc OnCloseFunc) net.PacketConn + sl = m.ln.Acquire() + return &virtualStreamListener{ + addr: sl.Addr(), + acceptCh: m.acceptCh, + closeCh: make(chan struct{}), + onCloseFunc: m.ln.Close, + }, nil } -var _ canCreatePacketListener = (*listenAddr)(nil) +type multiPacketListener struct { + mu sync.Mutex + addr string + pc RefCount[net.PacketConn] + onCloseFunc OnCloseFunc +} -// NewPacketListener creates a new [net.PacketConn]. -func (cl *listenAddr) NewPacketListener(onCloseFunc OnCloseFunc) net.PacketConn { - if ln, ok := cl.ln.(net.PacketConn); ok { - return &virtualPacketConn{ - PacketConn: ln, - onCloseFunc: onCloseFunc, - } +// NewMultiPacketListener creates a new packet-based [MultiListener]. +func NewMultiPacketListener(addr string, onCloseFunc OnCloseFunc) MultiListener[net.PacketConn] { + return &multiPacketListener{ + addr: addr, + onCloseFunc: onCloseFunc, } - return nil } -func (la *listenAddr) Close() error { - if err := la.ln.Close(); err != nil { - return err +func (m *multiPacketListener) Acquire() (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + + var pc net.PacketConn + if m.pc == nil { + pc, err := net.ListenPacket("udp", m.addr) + if err != nil { + return nil, err + } + m.pc = NewRefCount(pc, m.onCloseFunc) } - return la.onCloseFunc() + pc = m.pc.Acquire() + return &virtualPacketConn{ + PacketConn: pc, + onCloseFunc: m.pc.Close, + }, nil } -// ListenerManager holds and manages the state of shared listeners. +// ListenerManager holds the state of shared listeners. type ListenerManager interface { // ListenStream creates a new stream listener for a given address. - // - // Listeners can overlap one another, because during config changes the new - // config is started before the old config is destroyed. This is done by using - // reusable listener wrappers, which do not actually close the underlying socket - // until all uses of the shared listener have been closed. ListenStream(addr string) (StreamListener, error) // ListenPacket creates a new packet listener for a given address. - // - // See notes on [ListenStream]. ListenPacket(addr string) (net.PacketConn, error) } type listenerManager struct { - listeners map[string]RefCount[Listener] - listenersMu sync.Mutex + streamListeners map[string]MultiListener[StreamListener] + packetListeners map[string]MultiListener[net.PacketConn] + mu sync.Mutex } // NewListenerManager creates a new [ListenerManger]. func NewListenerManager() ListenerManager { return &listenerManager{ - listeners: make(map[string]RefCount[Listener]), - } -} - -func (m *listenerManager) newStreamListener(addr string) (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - streamLn := &TCPListener{ln} - lnAddr := &listenAddr{ - ln: streamLn, - acceptCh: make(chan acceptResponse), - onCloseFunc: func() error { - m.delete(listenerKey("tcp", addr)) - return nil - }, + streamListeners: make(map[string]MultiListener[StreamListener]), + packetListeners: make(map[string]MultiListener[net.PacketConn]), } - go func() { - for { - conn, err := streamLn.AcceptStream() - if errors.Is(err, net.ErrClosed) { - close(lnAddr.acceptCh) - return - } - lnAddr.acceptCh <- acceptResponse{conn, err} - } - }() - return lnAddr, nil } -func (m *listenerManager) newPacketListener(addr string) (Listener, error) { - pc, err := net.ListenPacket("udp", addr) - if err != nil { - return nil, err - } - return &listenAddr{ - ln: pc, - onCloseFunc: func() error { - m.delete(listenerKey("udp", addr)) - return nil - }, - }, nil -} - -func (m *listenerManager) getListener(network string, addr string) (RefCount[Listener], error) { - m.listenersMu.Lock() - defer m.listenersMu.Unlock() - - lnKey := listenerKey(network, addr) - if lnRefCount, exists := m.listeners[lnKey]; exists { - return lnRefCount.Acquire(), nil - } - - var ( - ln Listener - err error - ) - if network == "tcp" { - ln, err = m.newStreamListener(addr) - } else if network == "udp" { - ln, err = m.newPacketListener(addr) - } else { - return nil, fmt.Errorf("unable to get listener for unsupported network %s", network) +func (m *listenerManager) ListenStream(addr string) (StreamListener, error) { + m.mu.Lock() + defer m.mu.Unlock() + + streamLn, exists := m.streamListeners[addr] + if !exists { + streamLn = NewMultiStreamListener( + addr, + func() error { + m.mu.Lock() + delete(m.streamListeners, addr) + m.mu.Unlock() + return nil + }, + ) + m.streamListeners[addr] = streamLn } + ln, err := streamLn.Acquire() if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create stream listener for %s: %v", addr, err) } - lnRefCount := NewRefCount(ln) - m.listeners[lnKey] = lnRefCount - return lnRefCount, nil + return ln, nil } -func (m *listenerManager) ListenStream(addr string) (StreamListener, error) { - lnRefCount, err := m.getListener("tcp", addr) - if err != nil { - return nil, err - } - if lnAddr, ok := lnRefCount.Get().(canCreateStreamListener); ok { - return lnAddr.NewStreamListener(lnRefCount.Close), nil +func (m *listenerManager) ListenPacket(addr string) (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + + packetLn, exists := m.packetListeners[addr] + if !exists { + packetLn = NewMultiPacketListener( + addr, + func() error { + m.mu.Lock() + delete(m.packetListeners, addr) + m.mu.Unlock() + return nil + }, + ) + m.packetListeners[addr] = packetLn } - return nil, fmt.Errorf("unable to create stream listener for %s", addr) -} -func (m *listenerManager) ListenPacket(addr string) (net.PacketConn, error) { - lnRefCount, err := m.getListener("udp", addr) + ln, err := packetLn.Acquire() if err != nil { - return nil, err - } - if lnAddr, ok := lnRefCount.Get().(canCreatePacketListener); ok { - return lnAddr.NewPacketListener(lnRefCount.Close), nil + return nil, fmt.Errorf("unable to create packet listener for %s: %v", addr, err) } - return nil, fmt.Errorf("unable to create packet listener for %s", addr) -} - -func (m *listenerManager) delete(key string) { - m.listenersMu.Lock() - delete(m.listeners, key) - m.listenersMu.Unlock() -} - -func listenerKey(network string, addr string) string { - return network + "/" + addr + return ln, nil } // RefCount is an atomic reference counter that can be used to track a shared @@ -299,42 +293,44 @@ type RefCount[T io.Closer] interface { io.Closer // Acquire increases the ref count and returns the wrapped object. - Acquire() RefCount[T] - - Get() T + Acquire() T } type refCount[T io.Closer] struct { - mu sync.Mutex - count *atomic.Int32 - value T + mu sync.Mutex + count *atomic.Int32 + value T + onCloseFunc OnCloseFunc } -func NewRefCount[T io.Closer](value T) RefCount[T] { +func NewRefCount[T io.Closer](value T, onCloseFunc OnCloseFunc) RefCount[T] { r := &refCount[T]{ - count: &atomic.Int32{}, - value: value, + count: &atomic.Int32{}, + value: value, + onCloseFunc: onCloseFunc, } - r.count.Store(1) return r } +func (r refCount[T]) Acquire() T { + r.count.Add(1) + return r.value +} + func (r refCount[T]) Close() error { // Lock to prevent someone from acquiring while we close the value. r.mu.Lock() defer r.mu.Unlock() if count := r.count.Add(-1); count == 0 { - return r.value.Close() + err := r.value.Close() + if err != nil { + return err + } + if r.onCloseFunc != nil { + return r.onCloseFunc() + } + return nil } return nil } - -func (r refCount[T]) Acquire() RefCount[T] { - r.count.Add(1) - return r -} - -func (r refCount[T]) Get() T { - return r.value -} diff --git a/service/listeners_test.go b/service/listeners_test.go index da5aaa1e..0a840cff 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -97,17 +97,27 @@ func (t *testRefCount) Close() error { } func TestRefCount(t *testing.T) { - var done bool - rc := NewRefCount[*testRefCount](&testRefCount{ - onCloseFunc: func() { - done = true + var objectCloseDone bool + var onCloseFuncDone bool + rc := NewRefCount[*testRefCount]( + &testRefCount{ + onCloseFunc: func() { + objectCloseDone = true + }, }, - }) + func() error { + onCloseFuncDone = true + return nil + }, + ) + rc.Acquire() rc.Acquire() require.NoError(t, rc.Close()) - require.False(t, done) + require.False(t, objectCloseDone) + require.False(t, onCloseFuncDone) require.NoError(t, rc.Close()) - require.True(t, done) + require.True(t, objectCloseDone) + require.True(t, onCloseFuncDone) } From 899d13d80af21ed132b670916e8ec3fb22e1f762 Mon Sep 17 00:00:00 2001 From: sbruens Date: Tue, 6 Aug 2024 17:57:12 -0400 Subject: [PATCH 05/10] Do the lazy initialization inside an anonymous function. --- service/listeners.go | 87 +++++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 38 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 91aa877f..9814adca 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -139,45 +139,50 @@ type multiStreamListener struct { func NewMultiStreamListener(addr string, onCloseFunc OnCloseFunc) MultiListener[StreamListener] { return &multiStreamListener{ addr: addr, - acceptCh: make(chan acceptResponse), onCloseFunc: onCloseFunc, } } func (m *multiStreamListener) Acquire() (StreamListener, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var sl StreamListener - if m.ln == nil { - tcpAddr, err := net.ResolveTCPAddr("tcp", m.addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - sl = &TCPListener{ln} - m.ln = NewRefCount(sl, m.onCloseFunc) - go func() { - for { - conn, err := sl.AcceptStream() - if errors.Is(err, net.ErrClosed) { - close(m.acceptCh) - return - } - m.acceptCh <- acceptResponse{conn, err} + refCount, err := func() (RefCount[StreamListener], error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.ln == nil { + tcpAddr, err := net.ResolveTCPAddr("tcp", m.addr) + if err != nil { + return nil, err } - }() + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + sl := &TCPListener{ln} + m.ln = NewRefCount[StreamListener](sl, m.onCloseFunc) + m.acceptCh = make(chan acceptResponse) + go func() { + for { + conn, err := sl.AcceptStream() + if errors.Is(err, net.ErrClosed) { + close(m.acceptCh) + return + } + m.acceptCh <- acceptResponse{conn, err} + } + }() + } + return m.ln, nil + }() + if err != nil { + return nil, err } - sl = m.ln.Acquire() + sl := refCount.Acquire() return &virtualStreamListener{ addr: sl.Addr(), acceptCh: m.acceptCh, closeCh: make(chan struct{}), - onCloseFunc: m.ln.Close, + onCloseFunc: refCount.Close, }, nil } @@ -197,21 +202,27 @@ func NewMultiPacketListener(addr string, onCloseFunc OnCloseFunc) MultiListener[ } func (m *multiPacketListener) Acquire() (net.PacketConn, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var pc net.PacketConn - if m.pc == nil { - pc, err := net.ListenPacket("udp", m.addr) - if err != nil { - return nil, err + refCount, err := func() (RefCount[net.PacketConn], error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.pc == nil { + pc, err := net.ListenPacket("udp", m.addr) + if err != nil { + return nil, err + } + m.pc = NewRefCount(pc, m.onCloseFunc) } - m.pc = NewRefCount(pc, m.onCloseFunc) + return m.pc, nil + }() + if err != nil { + return nil, err } - pc = m.pc.Acquire() + + pc := refCount.Acquire() return &virtualPacketConn{ PacketConn: pc, - onCloseFunc: m.pc.Close, + onCloseFunc: refCount.Close, }, nil } From 80e5d491c6c5027b45e8b6e45e76f2428cea3ab0 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 10:40:33 -0400 Subject: [PATCH 06/10] Fix concurrent access to `acceptCh` and `closeCh`. --- service/listeners.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/service/listeners.go b/service/listeners.go index 9814adca..1cb09d06 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -72,17 +72,23 @@ type acceptResponse struct { type OnCloseFunc func() error type virtualStreamListener struct { + mu sync.Mutex // Mutex to protect access to the channels addr net.Addr acceptCh <-chan acceptResponse closeCh chan struct{} + closed bool onCloseFunc OnCloseFunc } var _ StreamListener = (*virtualStreamListener)(nil) func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { + sl.mu.Lock() + acceptCh := sl.acceptCh + sl.mu.Unlock() + select { - case acceptResponse, ok := <-sl.acceptCh: + case acceptResponse, ok := <-acceptCh: if !ok { return nil, net.ErrClosed } @@ -93,8 +99,16 @@ func (sl *virtualStreamListener) AcceptStream() (transport.StreamConn, error) { } func (sl *virtualStreamListener) Close() error { + sl.mu.Lock() + if sl.closed { + sl.mu.Unlock() + return nil + } + sl.closed = true sl.acceptCh = nil close(sl.closeCh) + sl.mu.Unlock() + if sl.onCloseFunc != nil { return sl.onCloseFunc() } From aa00f2efe1d19632e30c2cd27f7a018e92d5d4a4 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 11:42:56 -0400 Subject: [PATCH 07/10] Use `/` in key instead of `-`. --- cmd/outline-ss-server/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 55a11acc..9fd570f7 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -107,7 +107,7 @@ func (ls *listenerSet) ListenStream(addr string) (service.StreamListener, error) ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - lnKey := "stream-" + addr + lnKey := "stream/" + addr if _, exists := ls.listenerCloseFuncs[lnKey]; exists { return nil, fmt.Errorf("stream listener for %s already exists", addr) } @@ -125,7 +125,7 @@ func (ls *listenerSet) ListenPacket(addr string) (net.PacketConn, error) { ls.listenersMu.Lock() defer ls.listenersMu.Unlock() - lnKey := "packet-" + addr + lnKey := "packet/" + addr if _, exists := ls.listenerCloseFuncs[lnKey]; exists { return nil, fmt.Errorf("packet listener for %s already exists", addr) } From e658b90573a79bd26cf19f468da47f73ca694a07 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 11:51:19 -0400 Subject: [PATCH 08/10] Return error from stopping listeners. --- cmd/outline-ss-server/main.go | 34 ++++++++++++++++++++-------- cmd/outline-ss-server/server_test.go | 4 +++- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index 9fd570f7..aa8bb55e 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -61,7 +61,7 @@ func init() { } type SSServer struct { - stopConfig func() + stopConfig func() error lnManager service.ListenerManager natTimeout time.Duration m *outlineMetrics @@ -76,12 +76,14 @@ func (s *SSServer) loadConfig(filename string) error { // We hot swap the config by having the old and new listeners both live at // the same time. This means we create listeners for the new config first, // and then close the old ones after. - stopConfig, err := s.runConfig(*config) + sopConfig, err := s.runConfig(*config) if err != nil { return err } - go s.stopConfig() - s.stopConfig = stopConfig + if err := s.Stop(); err != nil { + return fmt.Errorf("unable to stop old config: %v", err) + } + s.stopConfig = sopConfig return nil } @@ -156,8 +158,9 @@ func (ls *listenerSet) Len() int { return len(ls.listenerCloseFuncs) } -func (s *SSServer) runConfig(config Config) (func(), error) { +func (s *SSServer) runConfig(config Config) (func() error, error) { startErrCh := make(chan error) + stopErrCh := make(chan error) stopCh := make(chan struct{}) go func() { @@ -165,7 +168,9 @@ func (s *SSServer) runConfig(config Config) (func(), error) { manager: s.lnManager, listenerCloseFuncs: make(map[string]func() error), } - defer lnSet.Close() // This closes all the listeners in the set. + defer func() { + stopErrCh <- lnSet.Close() + }() startErrCh <- func() error { portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. @@ -216,24 +221,33 @@ func (s *SSServer) runConfig(config Config) (func(), error) { if err != nil { return nil, err } - return func() { + return func() error { logger.Infof("Stopping running config.") // TODO(sbruens): Actually wait for all handlers to be stopped, e.g. by // using a https://pkg.go.dev/sync#WaitGroup. stopCh <- struct{}{} + stopErr := <-stopErrCh + return stopErr }, nil } // Stop stops serving the current config. -func (s *SSServer) Stop() { - go s.stopConfig() +func (s *SSServer) Stop() error { + stopFunc := s.stopConfig + if stopFunc == nil { + return nil + } + if err := stopFunc(); err != nil { + logger.Errorf("Error stopping config: %v", err) + return err + } logger.Info("Stopped all listeners for running config") + return nil } // RunSSServer starts a shadowsocks server running, and returns the server or an error. func RunSSServer(filename string, natTimeout time.Duration, sm *outlineMetrics, replayHistory int) (*SSServer, error) { server := &SSServer{ - stopConfig: func() {}, lnManager: service.NewListenerManager(), natTimeout: natTimeout, m: sm, diff --git a/cmd/outline-ss-server/server_test.go b/cmd/outline-ss-server/server_test.go index 2ba0772e..0b7777b2 100644 --- a/cmd/outline-ss-server/server_test.go +++ b/cmd/outline-ss-server/server_test.go @@ -27,5 +27,7 @@ func TestRunSSServer(t *testing.T) { if err != nil { t.Fatalf("RunSSServer() error = %v", err) } - server.Stop() + if err := server.Stop(); err != nil { + t.Errorf("Error while stopping server: %v", err) + } } From fede4d8d7764e452951dc7f59a0a7bb38b828b41 Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 13:18:06 -0400 Subject: [PATCH 09/10] Use channels to ensure `virtualPacketConn`s get closed. --- service/listeners.go | 63 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 1cb09d06..788d651f 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -64,13 +64,13 @@ func (t *TCPListener) Addr() net.Addr { return t.ln.Addr() } +type OnCloseFunc func() error + type acceptResponse struct { conn transport.StreamConn err error } -type OnCloseFunc func() error - type virtualStreamListener struct { mu sync.Mutex // Mutex to protect access to the channels addr net.Addr @@ -119,14 +119,52 @@ func (sl *virtualStreamListener) Addr() net.Addr { return sl.addr } +type packetResponse struct { + n int + addr net.Addr + err error + data []byte +} + type virtualPacketConn struct { net.PacketConn + mu sync.Mutex // Mutex to protect access to the channels + readCh <-chan packetResponse + closeCh chan struct{} + closed bool onCloseFunc OnCloseFunc } -func (spc *virtualPacketConn) Close() error { - if spc.onCloseFunc != nil { - return spc.onCloseFunc() +func (pc *virtualPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + pc.mu.Lock() + readCh := pc.readCh + pc.mu.Unlock() + + select { + case packetResponse, ok := <-readCh: + if !ok { + return 0, nil, net.ErrClosed + } + copy(p, packetResponse.data) + return packetResponse.n, packetResponse.addr, packetResponse.err + case <-pc.closeCh: + return 0, nil, net.ErrClosed + } +} + +func (pc *virtualPacketConn) Close() error { + pc.mu.Lock() + if pc.closed { + pc.mu.Unlock() + return nil + } + pc.closed = true + pc.readCh = nil + close(pc.closeCh) + pc.mu.Unlock() + + if pc.onCloseFunc != nil { + return pc.onCloseFunc() } return nil } @@ -204,6 +242,7 @@ type multiPacketListener struct { mu sync.Mutex addr string pc RefCount[net.PacketConn] + readCh chan packetResponse onCloseFunc OnCloseFunc } @@ -226,6 +265,18 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) { return nil, err } m.pc = NewRefCount(pc, m.onCloseFunc) + m.readCh = make(chan packetResponse) + go func() { + for { + buffer := make([]byte, serverUDPBufferSize) + n, addr, err := pc.ReadFrom(buffer) + if err != nil { + close(m.readCh) + return + } + m.readCh <- packetResponse{n: n, addr: addr, err: err, data: buffer[:n]} + } + }() } return m.pc, nil }() @@ -236,6 +287,8 @@ func (m *multiPacketListener) Acquire() (net.PacketConn, error) { pc := refCount.Acquire() return &virtualPacketConn{ PacketConn: pc, + readCh: m.readCh, + closeCh: make(chan struct{}), onCloseFunc: refCount.Close, }, nil } From 4730d741237e5f697975d5f0f1ae26c3cebc0fce Mon Sep 17 00:00:00 2001 From: sbruens Date: Wed, 7 Aug 2024 16:21:10 -0400 Subject: [PATCH 10/10] Add more test cases for packet listeners. --- service/listeners_test.go | 61 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/service/listeners_test.go b/service/listeners_test.go index 0a840cff..d627ec1a 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -51,18 +51,17 @@ func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { require.NoError(t, err) ln2, err := m.ListenStream("127.0.0.1:0") require.NoError(t, err) - // Close only the first listener. ln.Close() + done := make(chan struct{}) go func() { ln2.AcceptStream() done <- struct{}{} }() - err = writeTestPayload(ln2) - require.NoError(t, err) + require.NoError(t, err) <-done } @@ -82,8 +81,64 @@ func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) { done <- struct{}{} }() err = writeTestPayload(ln2) + + require.NoError(t, err) + <-done +} + +func TestListenerManagerPacketListenerEarlyClose(t *testing.T) { + m := NewListenerManager() + pc, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + + pc.Close() + _, _, readErr := pc.ReadFrom(nil) + _, writeErr := pc.WriteTo(nil, &net.UDPAddr{}) + + require.ErrorIs(t, readErr, net.ErrClosed) + require.ErrorIs(t, writeErr, net.ErrClosed) +} + +func TestListenerManagerPacketListenerNotClosedIfStillInUse(t *testing.T) { + m := NewListenerManager() + pc, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + pc2, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + // Close only the first listener. + pc.Close() + + done := make(chan struct{}) + go func() { + _, _, readErr := pc2.ReadFrom(nil) + require.NoError(t, readErr) + done <- struct{}{} + }() + _, err = pc.WriteTo(nil, pc2.LocalAddr()) + + require.NoError(t, err) + <-done +} + +func TestListenerManagerPacketListenerCreatesListenerOnDemand(t *testing.T) { + m := NewListenerManager() + // Create a listener and immediately close it. + pc, err := m.ListenPacket("127.0.0.1:0") require.NoError(t, err) + pc.Close() + // Now create another listener on the same address. + pc2, err := m.ListenPacket("127.0.0.1:0") + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + _, _, readErr := pc2.ReadFrom(nil) + require.NoError(t, readErr) + done <- struct{}{} + }() + _, err = pc2.WriteTo(nil, pc2.LocalAddr()) + require.NoError(t, err) <-done }