From 6b11f4ffc4bd2c9d25500a8512c35da742211715 Mon Sep 17 00:00:00 2001 From: sbruens Date: Fri, 2 Aug 2024 11:51:35 -0400 Subject: [PATCH] Refactor create methods. --- service/listeners.go | 117 ++++++++++++++++++++++++------------------- 1 file changed, 66 insertions(+), 51 deletions(-) diff --git a/service/listeners.go b/service/listeners.go index 44fa7ec5..8f0882ba 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -187,86 +187,97 @@ func NewListenerManager() ListenerManager { } } -func (m *listenerManager) getOrCreate(key string, createFunc func() (Listener, error)) (RefCount[Listener], error) { +func (m *listenerManager) newStreamListener(network string, addr string) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP(network, tcpAddr) + if err != nil { + return nil, err + } + streamLn := &TCPListener{ln} + lnAddr := &listenAddr{ + ln: streamLn, + acceptCh: make(chan acceptResponse), + onCloseFunc: func() error { + m.delete(listenerKey(network, addr)) + return nil + }, + } + go func() { + for { + conn, err := streamLn.AcceptStream() + if errors.Is(err, net.ErrClosed) { + close(lnAddr.acceptCh) + return + } + lnAddr.acceptCh <- acceptResponse{conn, err} + } + }() + return lnAddr, nil +} + +func (m *listenerManager) newPacketListener(network string, addr string) (Listener, error) { + pc, err := net.ListenPacket(network, addr) + if err != nil { + return nil, err + } + return &listenAddr{ + ln: pc, + onCloseFunc: func() error { + m.delete(listenerKey(network, addr)) + return nil + }, + }, nil +} + +func (m *listenerManager) getListener(network string, addr string) (RefCount[Listener], error) { m.listenersMu.Lock() defer m.listenersMu.Unlock() - if lnRefCount, exists := m.listeners[key]; exists { + lnKey := listenerKey(network, addr) + if lnRefCount, exists := m.listeners[lnKey]; exists { return lnRefCount.Acquire(), nil } - ln, err := createFunc() + var ( + ln Listener + err error + ) + if network == "tcp" { + ln, err = m.newStreamListener(network, addr) + } else { + ln, err = m.newPacketListener(network, addr) + } if err != nil { return nil, err } lnRefCount := NewRefCount(ln) - m.listeners[key] = lnRefCount + m.listeners[lnKey] = lnRefCount return lnRefCount, nil } func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) { - lnKey := network + "/" + addr - lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP(network, tcpAddr) - if err != nil { - return nil, err - } - streamLn := &TCPListener{ln} - lnAddr := &listenAddr{ - ln: streamLn, - acceptCh: make(chan acceptResponse), - onCloseFunc: func() error { - m.delete(lnKey) - return nil - }, - } - go func() { - for { - conn, err := streamLn.AcceptStream() - if errors.Is(err, net.ErrClosed) { - close(lnAddr.acceptCh) - return - } - lnAddr.acceptCh <- acceptResponse{conn, err} - } - }() - return lnAddr, nil - }) + lnRefCount, err := m.getListener(network, addr) if err != nil { return nil, err } if lnAddr, ok := lnRefCount.Get().(canCreateStreamListener); ok { return lnAddr.NewStreamListener(lnRefCount.Close), nil } - return nil, fmt.Errorf("unable to create stream listener for %s", lnKey) + return nil, fmt.Errorf("unable to create stream listener for %s/%s", network, addr) } func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) { - lnKey := network + "/" + addr - lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) { - pc, err := net.ListenPacket(network, addr) - if err != nil { - return nil, err - } - return &listenAddr{ - ln: pc, - onCloseFunc: func() error { - m.delete(lnKey) - return nil - }, - }, nil - }) + lnRefCount, err := m.getListener(network, addr) if err != nil { return nil, err } if lnAddr, ok := lnRefCount.Get().(canCreatePacketListener); ok { return lnAddr.NewPacketListener(lnRefCount.Close), nil } - return nil, fmt.Errorf("unable to create packet listener for %s", lnKey) + return nil, fmt.Errorf("unable to create packet listener for %s/%s", network, addr) } func (m *listenerManager) delete(key string) { @@ -275,6 +286,10 @@ func (m *listenerManager) delete(key string) { m.listenersMu.Unlock() } +func listenerKey(network string, addr string) string { + return network + "/" + addr +} + // RefCount is an atomic reference counter that can be used to track a shared // [io.Closer] resource. type RefCount[T io.Closer] interface {