Skip to content

Commit

Permalink
Merge branch 'sbruens/shared-listeners' into sbruens/proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
sbruens committed Aug 2, 2024
2 parents f9432d2 + 6b11f4f commit 3e03394
Showing 1 changed file with 66 additions and 51 deletions.
117 changes: 66 additions & 51 deletions service/listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,86 +187,97 @@ func NewListenerManager() ListenerManager {
}
}

func (m *listenerManager) getOrCreate(key string, createFunc func() (Listener, error)) (RefCount[Listener], error) {
func (m *listenerManager) newStreamListener(network string, addr string) (Listener, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
ln, err := net.ListenTCP(network, tcpAddr)
if err != nil {
return nil, err
}
streamLn := &TCPListener{ln}
lnAddr := &listenAddr{
ln: streamLn,
acceptCh: make(chan acceptResponse),
onCloseFunc: func() error {
m.delete(listenerKey(network, addr))
return nil
},
}
go func() {
for {
conn, err := streamLn.AcceptStream()
if errors.Is(err, net.ErrClosed) {
close(lnAddr.acceptCh)
return
}
lnAddr.acceptCh <- acceptResponse{conn, err}
}
}()
return lnAddr, nil
}

func (m *listenerManager) newPacketListener(network string, addr string) (Listener, error) {
pc, err := net.ListenPacket(network, addr)
if err != nil {
return nil, err
}
return &listenAddr{
ln: pc,
onCloseFunc: func() error {
m.delete(listenerKey(network, addr))
return nil
},
}, nil
}

func (m *listenerManager) getListener(network string, addr string) (RefCount[Listener], error) {
m.listenersMu.Lock()
defer m.listenersMu.Unlock()

if lnRefCount, exists := m.listeners[key]; exists {
lnKey := listenerKey(network, addr)
if lnRefCount, exists := m.listeners[lnKey]; exists {
return lnRefCount.Acquire(), nil
}

ln, err := createFunc()
var (
ln Listener
err error
)
if network == "tcp" {
ln, err = m.newStreamListener(network, addr)
} else {
ln, err = m.newPacketListener(network, addr)
}
if err != nil {
return nil, err
}
lnRefCount := NewRefCount(ln)
m.listeners[key] = lnRefCount
m.listeners[lnKey] = lnRefCount
return lnRefCount, nil
}

func (m *listenerManager) ListenStream(network string, addr string) (StreamListener, error) {
lnKey := network + "/" + addr
lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
ln, err := net.ListenTCP(network, tcpAddr)
if err != nil {
return nil, err
}
streamLn := &TCPListener{ln}
lnAddr := &listenAddr{
ln: streamLn,
acceptCh: make(chan acceptResponse),
onCloseFunc: func() error {
m.delete(lnKey)
return nil
},
}
go func() {
for {
conn, err := streamLn.AcceptStream()
if errors.Is(err, net.ErrClosed) {
close(lnAddr.acceptCh)
return
}
lnAddr.acceptCh <- acceptResponse{conn, err}
}
}()
return lnAddr, nil
})
lnRefCount, err := m.getListener(network, addr)
if err != nil {
return nil, err
}
if lnAddr, ok := lnRefCount.Get().(canCreateStreamListener); ok {
return lnAddr.NewStreamListener(lnRefCount.Close), nil
}
return nil, fmt.Errorf("unable to create stream listener for %s", lnKey)
return nil, fmt.Errorf("unable to create stream listener for %s/%s", network, addr)
}

func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketConn, error) {
lnKey := network + "/" + addr
lnRefCount, err := m.getOrCreate(lnKey, func() (Listener, error) {
pc, err := net.ListenPacket(network, addr)
if err != nil {
return nil, err
}
return &listenAddr{
ln: pc,
onCloseFunc: func() error {
m.delete(lnKey)
return nil
},
}, nil
})
lnRefCount, err := m.getListener(network, addr)
if err != nil {
return nil, err
}
if lnAddr, ok := lnRefCount.Get().(canCreatePacketListener); ok {
return lnAddr.NewPacketListener(lnRefCount.Close), nil
}
return nil, fmt.Errorf("unable to create packet listener for %s", lnKey)
return nil, fmt.Errorf("unable to create packet listener for %s/%s", network, addr)
}

func (m *listenerManager) delete(key string) {
Expand All @@ -275,6 +286,10 @@ func (m *listenerManager) delete(key string) {
m.listenersMu.Unlock()
}

func listenerKey(network string, addr string) string {
return network + "/" + addr
}

// RefCount is an atomic reference counter that can be used to track a shared
// [io.Closer] resource.
type RefCount[T io.Closer] interface {
Expand Down

0 comments on commit 3e03394

Please sign in to comment.