diff --git a/service/listeners.go b/service/listeners.go index 4037771e..02258ae1 100644 --- a/service/listeners.go +++ b/service/listeners.go @@ -45,10 +45,7 @@ type clientStreamConn struct { } func (c *clientStreamConn) ClientAddr() net.Addr { - if c.clientAddr != nil { - return c.clientAddr - } - return c.StreamConn.RemoteAddr() + return c.clientAddr } // StreamListener is a network listener for stream-oriented protocols that @@ -124,7 +121,7 @@ func (sl *virtualStreamListener) Addr() net.Addr { return sl.listener.Addr() } -// ProxyListener wraps a [StreamListener] and fetches the source of the connection from the PROXY +// ProxyStreamListener wraps a [StreamListener] and fetches the source of the connection from the PROXY // protocol header string. See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. type ProxyStreamListener struct { StreamListener diff --git a/service/listeners_test.go b/service/listeners_test.go index 384126ad..e4e8c7c6 100644 --- a/service/listeners_test.go +++ b/service/listeners_test.go @@ -19,12 +19,89 @@ import ( "net" "testing" + "github.com/pires/go-proxyproto" "github.com/stretchr/testify/require" ) +func TestDirectListenerSetsRemoteAddrAsClientAddr(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err) + conn.Write(makeTestPayload(50)) + conn.Close() + }() + + ln := &TCPListener{listener} + conn, err := ln.AcceptStream() + require.NoError(t, err) + require.Equal(t, conn.RemoteAddr(), conn.ClientAddr()) +} + +func TestProxyProtocolListenerParsesSourceAddressAsClientAddr(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + sourceAddr := &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + } + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err) + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: sourceAddr, + DestinationAddr: conn.RemoteAddr(), + } + header.WriteTo(conn) + conn.Write(makeTestPayload(50)) + conn.Close() + }() + + ln := &ProxyStreamListener{StreamListener: &TCPListener{listener}} + conn, err := ln.AcceptStream() + require.NoError(t, err) + require.True(t, sourceAddr.IP.Equal(conn.ClientAddr().(*net.TCPAddr).IP)) + require.Equal(t, sourceAddr.Port, conn.ClientAddr().(*net.TCPAddr).Port) +} + +func TestProxyProtocolListenerUsesRemoteAddrAsClientAddrIfLocalHeader(t *testing.T) { + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + go func() { + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoErrorf(t, err, "Failed to dial %v: %v", listener.Addr(), err) + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.LOCAL, + TransportProtocol: proxyproto.UNSPEC, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP("10.1.1.1"), + Port: 1000, + }, + DestinationAddr: conn.RemoteAddr(), + } + header.WriteTo(conn) + conn.Write(makeTestPayload(50)) + conn.Close() + }() + + ln := &ProxyStreamListener{StreamListener: &TCPListener{listener}} + conn, err := ln.AcceptStream() + require.NoError(t, err) + require.Equal(t, conn.RemoteAddr(), conn.ClientAddr()) +} + func TestListenerManagerStreamListenerEarlyClose(t *testing.T) { m := NewListenerManager() - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) ln.Close() @@ -47,9 +124,9 @@ func writeTestPayload(ln StreamListener) error { func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { m := NewListenerManager() - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) - ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + ln2, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) // Close only the first listener. @@ -69,11 +146,11 @@ func TestListenerManagerStreamListenerNotClosedIfStillInUse(t *testing.T) { func TestListenerManagerStreamListenerCreatesListenerOnDemand(t *testing.T) { m := NewListenerManager() // Create a listener and immediately close it. - ln, err := m.ListenStream("tcp", "127.0.0.1:0") + ln, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) ln.Close() // Now create another listener on the same address. - ln2, err := m.ListenStream("tcp", "127.0.0.1:0") + ln2, err := m.ListenStream("tcp", "127.0.0.1:0", false) require.NoError(t, err) done := make(chan struct{}) diff --git a/service/tcp.go b/service/tcp.go index 054cc01e..b4188fa0 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -215,7 +215,7 @@ type StreamAcceptFunc func() (ClientStreamConn, error) func WrapStreamAcceptFunc[T transport.StreamConn](f func() (T, error)) StreamAcceptFunc { return func() (ClientStreamConn, error) { c, err := f() - return &clientStreamConn{StreamConn: c}, err + return &clientStreamConn{StreamConn: c, clientAddr: c.RemoteAddr()}, err } }