diff --git a/tcpuid_linux.go b/tcpuid_linux.go index 7aaa718..112ac9f 100644 --- a/tcpuid_linux.go +++ b/tcpuid_linux.go @@ -44,6 +44,14 @@ func TCPClientUIDSupported() bool { // TCPClientUID obtains UID of client process that created // TCP connection over the loopback interface func TCPClientUID(client, server *net.TCPAddr) (int, error) { + // Obtain protocol family. Check for mismatch. + clientIs4 := client.IP.To4() != nil + serverIs4 := server.IP.To4() != nil + + if clientIs4 != serverIs4 { + return -1, fmt.Errorf("TCPClientUID: IP4/IP6 mismatchh") + } + // Open NETLINK_SOCK_DIAG socket sock, err := sockDiagOpen() if err != nil { @@ -59,7 +67,20 @@ func TCPClientUID(client, server *net.TCPAddr) (int, error) { rq.hdr.nlmsg_type = C.uint16_t(C.SOCK_DIAG_BY_FAMILY) rq.hdr.nlmsg_flags = C.uint16_t(C.NLM_F_REQUEST) - rq.data.sdiag_family = C.AF_INET6 + if clientIs4 { + rq.data.sdiag_family = C.AF_INET + copy((*[16]byte)(unsafe.Pointer(&rq.data.id.idiag_src))[:], + client.IP.To4()) + copy((*[16]byte)(unsafe.Pointer(&rq.data.id.idiag_dst))[:], + server.IP.To4()) + } else { + rq.data.sdiag_family = C.AF_INET6 + copy((*[16]byte)(unsafe.Pointer(&rq.data.id.idiag_src))[:], + client.IP.To16()) + copy((*[16]byte)(unsafe.Pointer(&rq.data.id.idiag_dst))[:], + server.IP.To16()) + } + rq.data.sdiag_protocol = C.IPPROTO_TCP rq.data.idiag_states = 1 << C.TCP_ESTABLISHED rq.data.id.idiag_sport = C.uint16_t(toBE16((uint16(client.Port)))) @@ -67,11 +88,6 @@ func TCPClientUID(client, server *net.TCPAddr) (int, error) { rq.data.id.idiag_cookie[0] = C.INET_DIAG_NOCOOKIE rq.data.id.idiag_cookie[1] = C.INET_DIAG_NOCOOKIE - copy((*[16]byte)(unsafe.Pointer(&rq.data.id.idiag_src))[:], - client.IP.To16()) - copy((*[16]byte)(unsafe.Pointer(&rq.data.id.idiag_dst))[:], - server.IP.To16()) - // Send request rqData := (*[unsafe.Sizeof(rq)]byte)(unsafe.Pointer(&rq)) rqAddr := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK} diff --git a/tcpuid_test.go b/tcpuid_test.go index e6db60b..d37cb61 100644 --- a/tcpuid_test.go +++ b/tcpuid_test.go @@ -16,25 +16,52 @@ import ( // doTestTCPClientUID performs TCPClientUID for the specified // network and loopback address -func doTestTCPClientUID(t *testing.T, network, loopback string) { +func doTestTCPClientUID(t *testing.T, ip4 bool) { // Do nothing if TCPClientUID is not supported by the platform if !TCPClientUIDSupported() { return } - // Log local addresses + // Log local addresses. Check that we have appropriate + // address family support, configured in the system. + var haveIP4, haveIP6 bool + if ift, err := net.Interfaces(); err == nil { for _, ifi := range ift { if addrs, err := ifi.Addrs(); err == nil { t.Logf("%s:", ifi.Name) for _, addr := range addrs { t.Logf(" %s", addr) + + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.To4() != nil { + haveIP4 = true + } else { + haveIP6 = true + } + } } } } } + // Skip incompatible address families + if ip4 && !haveIP4 { + return + } + + if !ip4 && !haveIP6 { + return + } + // Create loopback listener -- it gives us a port + network := "tcp4" + loopback := "127.0.0.1" + if !ip4 { + loopback = "[::1]" + network = "tcp6" + } + l, err := net.Listen(network, loopback+":") if err != nil { t.Fatalf("net.Listen(%q,%q): %s", network, loopback+":", err) @@ -76,10 +103,10 @@ func doTestTCPClientUID(t *testing.T, network, loopback string) { // TestTCPClientUIDIp4 performs TCPClientUID test for IPv4 func TestTCPClientUIDIp4(t *testing.T) { - doTestTCPClientUID(t, "tcp", "127.0.0.1") + doTestTCPClientUID(t, true) } // TestTCPClientUIDIp6 performs TCPClientUID test for IPv6 func TestTCPClientUIDIp6(t *testing.T) { - doTestTCPClientUID(t, "tcp6", "[::1]") + doTestTCPClientUID(t, false) }