Skip to content

Commit

Permalink
Add test scenarios for client addr.
Browse files Browse the repository at this point in the history
  • Loading branch information
sbruens committed Aug 5, 2024
1 parent 827be3c commit 2622b5e
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 11 deletions.
7 changes: 2 additions & 5 deletions service/listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@ type clientStreamConn struct {
}

func (c *clientStreamConn) ClientAddr() net.Addr {
if c.clientAddr != nil {
return c.clientAddr
}
return c.StreamConn.RemoteAddr()
return c.clientAddr
}

// StreamListener is a network listener for stream-oriented protocols that
Expand Down Expand Up @@ -124,7 +121,7 @@ func (sl *virtualStreamListener) Addr() net.Addr {
return sl.listener.Addr()
}

// ProxyListener wraps a [StreamListener] and fetches the source of the connection from the PROXY
// ProxyStreamListener wraps a [StreamListener] and fetches the source of the connection from the PROXY
// protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt.
type ProxyStreamListener struct {
StreamListener
Expand Down
87 changes: 82 additions & 5 deletions service/listeners_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,89 @@ import (
"net"
"testing"

"github.com/pires/go-proxyproto"
"github.com/stretchr/testify/require"
)

func TestDirectListenerSetsRemoteAddrAsClientAddr(t *testing.T) {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
require.NoError(t, err)

go func() {
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err)
conn.Write(makeTestPayload(50))
conn.Close()
}()

ln := &TCPListener{listener}
conn, err := ln.AcceptStream()
require.NoError(t, err)
require.Equal(t, conn.RemoteAddr(), conn.ClientAddr())
}

func TestProxyProtocolListenerParsesSourceAddressAsClientAddr(t *testing.T) {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
require.NoError(t, err)

sourceAddr := &net.TCPAddr{
IP: net.ParseIP("10.1.1.1"),
Port: 1000,
}
go func() {
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err)
header := &proxyproto.Header{
Version: 2,
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4,
SourceAddr: sourceAddr,
DestinationAddr: conn.RemoteAddr(),
}
header.WriteTo(conn)
conn.Write(makeTestPayload(50))
conn.Close()
}()

ln := &ProxyStreamListener{StreamListener: &TCPListener{listener}}
conn, err := ln.AcceptStream()
require.NoError(t, err)
require.True(t, sourceAddr.IP.Equal(conn.ClientAddr().(*net.TCPAddr).IP))
require.Equal(t, sourceAddr.Port, conn.ClientAddr().(*net.TCPAddr).Port)
}

func TestProxyProtocolListenerUsesRemoteAddrAsClientAddrIfLocalHeader(t *testing.T) {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
require.NoError(t, err)

go func() {
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err)

header := &proxyproto.Header{
Version: 2,
Command: proxyproto.LOCAL,
TransportProtocol: proxyproto.UNSPEC,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP("10.1.1.1"),
Port: 1000,
},
DestinationAddr: conn.RemoteAddr(),
}
header.WriteTo(conn)
conn.Write(makeTestPayload(50))
conn.Close()
}()

ln := &ProxyStreamListener{StreamListener: &TCPListener{listener}}
conn, err := ln.AcceptStream()
require.NoError(t, err)
require.Equal(t, conn.RemoteAddr(), conn.ClientAddr())
}

func TestListenerManagerStreamListenerEarlyClose(t *testing.T) {
m := NewListenerManager()
ln, err := m.ListenStream("tcp", "127.0.0.1:0")
ln, err := m.ListenStream("tcp", "127.0.0.1:0", false)
require.NoError(t, err)

ln.Close()
Expand All @@ -47,9 +124,9 @@ func writeTestPayload(ln StreamListener) error {

func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) {
m := NewListenerManager()
ln, err := m.ListenStream("tcp", "127.0.0.1:0")
ln, err := m.ListenStream("tcp", "127.0.0.1:0", false)
require.NoError(t, err)
ln2, err := m.ListenStream("tcp", "127.0.0.1:0")
ln2, err := m.ListenStream("tcp", "127.0.0.1:0", false)
require.NoError(t, err)

// Close only the first listener.
Expand All @@ -69,11 +146,11 @@ func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) {
func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) {
m := NewListenerManager()
// Create a listener and immediately close it.
ln, err := m.ListenStream("tcp", "127.0.0.1:0")
ln, err := m.ListenStream("tcp", "127.0.0.1:0", false)
require.NoError(t, err)
ln.Close()
// Now create another listener on the same address.
ln2, err := m.ListenStream("tcp", "127.0.0.1:0")
ln2, err := m.ListenStream("tcp", "127.0.0.1:0", false)
require.NoError(t, err)

done := make(chan struct{})
Expand Down
2 changes: 1 addition & 1 deletion service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ type StreamAcceptFunc func() (ClientStreamConn, error)
func WrapStreamAcceptFunc[T transport.StreamConn](f func() (T, error)) StreamAcceptFunc {
return func() (ClientStreamConn, error) {
c, err := f()
return &clientStreamConn{StreamConn: c}, err
return &clientStreamConn{StreamConn: c, clientAddr: c.RemoteAddr()}, err
}
}

Expand Down

0 comments on commit 2622b5e

Please sign in to comment.