diff --git a/conn/cmsg.go b/conn/cmsg.go new file mode 100644 index 0000000..a48fc84 --- /dev/null +++ b/conn/cmsg.go @@ -0,0 +1,28 @@ +package conn + +import "net/netip" + +// SocketControlMessageBufferSize specifies the buffer size for receiving socket control messages. +const SocketControlMessageBufferSize = socketControlMessageBufferSize + +// SocketControlMessage contains information that can be parsed from or put into socket control messages. +type SocketControlMessage struct { + // PktinfoAddr is the IP address of the network interface the packet was received from. + PktinfoAddr netip.Addr + + // PktinfoIfindex is the index of the network interface the packet was received from. + PktinfoIfindex uint32 + + // SegmentSize is the UDP GRO/GSO segment size. + SegmentSize uint32 +} + +// ParseSocketControlMessage parses a sequence of socket control messages and returns the parsed information. +func ParseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) { + return parseSocketControlMessage(cmsg) +} + +// AppendTo appends the socket control message to the buffer. +func (m SocketControlMessage) AppendTo(b []byte) []byte { + return m.appendTo(b) +} diff --git a/conn/cmsg_darwin.go b/conn/cmsg_darwin.go new file mode 100644 index 0000000..52151d7 --- /dev/null +++ b/conn/cmsg_darwin.go @@ -0,0 +1,93 @@ +package conn + +import ( + "fmt" + "net/netip" + "unsafe" + + "github.com/database64128/swgp-go/slicehelper" + "golang.org/x/sys/unix" +) + +const socketControlMessageBufferSize = unix.SizeofCmsghdr + alignedSizeofInet6Pktinfo + +const cmsgAlignTo = 4 + +func cmsgAlign(n uint32) uint32 { + return (n + cmsgAlignTo - 1) & ^uint32(cmsgAlignTo-1) +} + +func parseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) { + for len(cmsg) >= unix.SizeofCmsghdr { + cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(cmsg))) + msgSize := cmsgAlign(cmsghdr.Len) + if cmsghdr.Len < unix.SizeofCmsghdr || int(msgSize) > len(cmsg) { + return m, fmt.Errorf("invalid control message length %d", cmsghdr.Len) + } + + switch { + case cmsghdr.Level == unix.IPPROTO_IP && cmsghdr.Type == unix.IP_PKTINFO: + if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet4Pktinfo { + return m, fmt.Errorf("invalid IP_PKTINFO control message length %d", cmsghdr.Len) + } + var pktinfo unix.Inet4Pktinfo + _ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo), cmsg[unix.SizeofCmsghdr:]) + m.PktinfoAddr = netip.AddrFrom4(pktinfo.Spec_dst) + m.PktinfoIfindex = pktinfo.Ifindex + + case cmsghdr.Level == unix.IPPROTO_IPV6 && cmsghdr.Type == unix.IPV6_PKTINFO: + if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet6Pktinfo { + return m, fmt.Errorf("invalid IPV6_PKTINFO control message length %d", cmsghdr.Len) + } + var pktinfo unix.Inet6Pktinfo + _ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo), cmsg[unix.SizeofCmsghdr:]) + m.PktinfoAddr = netip.AddrFrom16(pktinfo.Addr) + m.PktinfoIfindex = pktinfo.Ifindex + } + + cmsg = cmsg[msgSize:] + } + + return m, nil +} + +const ( + alignedSizeofInet4Pktinfo = (unix.SizeofInet4Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) + alignedSizeofInet6Pktinfo = (unix.SizeofInet6Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) +) + +func (m SocketControlMessage) appendTo(b []byte) []byte { + switch { + case m.PktinfoAddr.Is4(): + var msgBuf []byte + b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet4Pktinfo) + cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) + *cmsghdr = unix.Cmsghdr{ + Len: unix.SizeofCmsghdr + unix.SizeofInet4Pktinfo, + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + pktinfo := unix.Inet4Pktinfo{ + Ifindex: m.PktinfoIfindex, + Spec_dst: m.PktinfoAddr.As4(), + } + _ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo)) + + case m.PktinfoAddr.Is6(): + var msgBuf []byte + b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet6Pktinfo) + cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) + *cmsghdr = unix.Cmsghdr{ + Len: unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo, + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + pktinfo := unix.Inet6Pktinfo{ + Addr: m.PktinfoAddr.As16(), + Ifindex: m.PktinfoIfindex, + } + _ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo)) + } + + return b +} diff --git a/conn/cmsg_linux.go b/conn/cmsg_linux.go new file mode 100644 index 0000000..98dd56e --- /dev/null +++ b/conn/cmsg_linux.go @@ -0,0 +1,116 @@ +package conn + +import ( + "fmt" + "net/netip" + "unsafe" + + "github.com/database64128/swgp-go/slicehelper" + "golang.org/x/sys/unix" +) + +const socketControlMessageBufferSize = unix.SizeofCmsghdr + alignedSizeofInet6Pktinfo + + unix.SizeofCmsghdr + alignedSizeofGROSegmentSize + +const sizeofGROSegmentSize = int(unsafe.Sizeof(uint16(0))) + +func cmsgAlign(n uint64) uint64 { + return (n + unix.SizeofPtr - 1) & ^uint64(unix.SizeofPtr-1) +} + +func parseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) { + for len(cmsg) >= unix.SizeofCmsghdr { + cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(cmsg))) + msgSize := cmsgAlign(cmsghdr.Len) + if cmsghdr.Len < unix.SizeofCmsghdr || int(msgSize) > len(cmsg) { + return m, fmt.Errorf("invalid control message length %d", cmsghdr.Len) + } + + switch { + case cmsghdr.Level == unix.IPPROTO_IP && cmsghdr.Type == unix.IP_PKTINFO: + if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet4Pktinfo { + return m, fmt.Errorf("invalid IP_PKTINFO control message length %d", cmsghdr.Len) + } + var pktinfo unix.Inet4Pktinfo + _ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo), cmsg[unix.SizeofCmsghdr:]) + m.PktinfoAddr = netip.AddrFrom4(pktinfo.Spec_dst) + m.PktinfoIfindex = uint32(pktinfo.Ifindex) + + case cmsghdr.Level == unix.IPPROTO_IPV6 && cmsghdr.Type == unix.IPV6_PKTINFO: + if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet6Pktinfo { + return m, fmt.Errorf("invalid IPV6_PKTINFO control message length %d", cmsghdr.Len) + } + var pktinfo unix.Inet6Pktinfo + _ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo), cmsg[unix.SizeofCmsghdr:]) + m.PktinfoAddr = netip.AddrFrom16(pktinfo.Addr) + m.PktinfoIfindex = pktinfo.Ifindex + + case cmsghdr.Level == unix.IPPROTO_UDP && cmsghdr.Type == unix.UDP_GRO: + if len(cmsg) < unix.SizeofCmsghdr+sizeofGROSegmentSize { + return m, fmt.Errorf("invalid UDP_GRO control message length %d", cmsghdr.Len) + } + var segmentSize uint16 + _ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&segmentSize)), sizeofGROSegmentSize), cmsg[unix.SizeofCmsghdr:]) + m.SegmentSize = uint32(segmentSize) + } + + cmsg = cmsg[msgSize:] + } + + return m, nil +} + +const ( + alignedSizeofInet4Pktinfo = (unix.SizeofInet4Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) + alignedSizeofInet6Pktinfo = (unix.SizeofInet6Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) + alignedSizeofGROSegmentSize = (sizeofGROSegmentSize + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) +) + +func (m SocketControlMessage) appendTo(b []byte) []byte { + switch { + case m.PktinfoAddr.Is4(): + var msgBuf []byte + b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet4Pktinfo) + cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) + *cmsghdr = unix.Cmsghdr{ + Len: unix.SizeofCmsghdr + unix.SizeofInet4Pktinfo, + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + pktinfo := unix.Inet4Pktinfo{ + Ifindex: int32(m.PktinfoIfindex), + Spec_dst: m.PktinfoAddr.As4(), + } + _ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo)) + + case m.PktinfoAddr.Is6(): + var msgBuf []byte + b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet6Pktinfo) + cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) + *cmsghdr = unix.Cmsghdr{ + Len: unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo, + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + pktinfo := unix.Inet6Pktinfo{ + Addr: m.PktinfoAddr.As16(), + Ifindex: m.PktinfoIfindex, + } + _ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo)) + } + + if m.SegmentSize > 0 { + var msgBuf []byte + b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofGROSegmentSize) + cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) + *cmsghdr = unix.Cmsghdr{ + Len: unix.SizeofCmsghdr + uint64(sizeofGROSegmentSize), + Level: unix.IPPROTO_UDP, + Type: unix.UDP_GRO, + } + segmentSize := uint16(m.SegmentSize) + _ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&segmentSize)), sizeofGROSegmentSize)) + } + + return b +} diff --git a/conn/cmsg_stub.go b/conn/cmsg_stub.go new file mode 100644 index 0000000..e2bf9b5 --- /dev/null +++ b/conn/cmsg_stub.go @@ -0,0 +1,13 @@ +//go:build !darwin && !linux && !windows + +package conn + +const socketControlMessageBufferSize = 0 + +func parseSocketControlMessage(_ []byte) (SocketControlMessage, error) { + return SocketControlMessage{}, nil +} + +func (SocketControlMessage) appendTo(b []byte) []byte { + return b +} diff --git a/conn/cmsg_windows.go b/conn/cmsg_windows.go new file mode 100644 index 0000000..c56eb49 --- /dev/null +++ b/conn/cmsg_windows.go @@ -0,0 +1,139 @@ +package conn + +import ( + "fmt" + "net/netip" + "unsafe" + + "github.com/database64128/swgp-go/slicehelper" + "golang.org/x/sys/windows" +) + +const socketControlMessageBufferSize = sizeofCmsghdr + alignedSizeofInet6Pktinfo + + sizeofCmsghdr + alignedSizeofCoalescedInfo + +const ( + sizeofPtr = int(unsafe.Sizeof(uintptr(0))) + sizeofCmsghdr = int(unsafe.Sizeof(Cmsghdr{})) + sizeofInet4Pktinfo = int(unsafe.Sizeof(Inet4Pktinfo{})) + sizeofInet6Pktinfo = int(unsafe.Sizeof(Inet6Pktinfo{})) + sizeofCoalescedInfo = int(unsafe.Sizeof(uint32(0))) +) + +// Structure CMSGHDR from ws2def.h +type Cmsghdr struct { + Len uintptr + Level int32 + Type int32 +} + +// Structure IN_PKTINFO from ws2ipdef.h +type Inet4Pktinfo struct { + Addr [4]byte + Ifindex uint32 +} + +// Structure IN6_PKTINFO from ws2ipdef.h +type Inet6Pktinfo struct { + Addr [16]byte + Ifindex uint32 +} + +func cmsgAlign(n int) int { + return (n + sizeofPtr - 1) & ^(sizeofPtr - 1) +} + +func parseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) { + for len(cmsg) >= sizeofCmsghdr { + cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(cmsg))) + msgLen := int(cmsghdr.Len) + msgSize := cmsgAlign(msgLen) + if msgLen < sizeofCmsghdr || msgSize > len(cmsg) { + return m, fmt.Errorf("invalid control message length %d", cmsghdr.Len) + } + + switch { + case cmsghdr.Level == windows.IPPROTO_IP && cmsghdr.Type == windows.IP_PKTINFO: + if len(cmsg) < sizeofCmsghdr+sizeofInet4Pktinfo { + return m, fmt.Errorf("invalid IP_PKTINFO control message length %d", cmsghdr.Len) + } + var pktinfo Inet4Pktinfo + _ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet4Pktinfo), cmsg[sizeofCmsghdr:]) + m.PktinfoAddr = netip.AddrFrom4(pktinfo.Addr) + m.PktinfoIfindex = pktinfo.Ifindex + + case cmsghdr.Level == windows.IPPROTO_IPV6 && cmsghdr.Type == windows.IPV6_PKTINFO: + if len(cmsg) < sizeofCmsghdr+sizeofInet6Pktinfo { + return m, fmt.Errorf("invalid IPV6_PKTINFO control message length %d", cmsghdr.Len) + } + var pktinfo Inet6Pktinfo + _ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet6Pktinfo), cmsg[sizeofCmsghdr:]) + m.PktinfoAddr = netip.AddrFrom16(pktinfo.Addr) + m.PktinfoIfindex = pktinfo.Ifindex + + case cmsghdr.Level == windows.IPPROTO_UDP && cmsghdr.Type == windows.UDP_COALESCED_INFO: + if len(cmsg) < sizeofCmsghdr+sizeofCoalescedInfo { + return m, fmt.Errorf("invalid UDP_COALESCED_INFO control message length %d", cmsghdr.Len) + } + _ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&m.SegmentSize)), sizeofCoalescedInfo), cmsg[sizeofCmsghdr:]) + } + + cmsg = cmsg[msgSize:] + } + + return m, nil +} + +const ( + alignedSizeofInet4Pktinfo = (sizeofInet4Pktinfo + sizeofPtr - 1) & ^(sizeofPtr - 1) + alignedSizeofInet6Pktinfo = (sizeofInet6Pktinfo + sizeofPtr - 1) & ^(sizeofPtr - 1) + alignedSizeofCoalescedInfo = (sizeofCoalescedInfo + sizeofPtr - 1) & ^(sizeofPtr - 1) +) + +func (m SocketControlMessage) appendTo(b []byte) []byte { + switch { + case m.PktinfoAddr.Is4(): + var msgBuf []byte + b, msgBuf = slicehelper.Extend(b, sizeofCmsghdr+alignedSizeofInet4Pktinfo) + cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) + *cmsghdr = Cmsghdr{ + Len: uintptr(sizeofCmsghdr + sizeofInet4Pktinfo), + Level: windows.IPPROTO_IP, + Type: windows.IP_PKTINFO, + } + pktinfo := Inet4Pktinfo{ + Addr: m.PktinfoAddr.As4(), + Ifindex: m.PktinfoIfindex, + } + _ = copy(msgBuf[sizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet4Pktinfo)) + + case m.PktinfoAddr.Is6(): + var msgBuf []byte + b, msgBuf = slicehelper.Extend(b, sizeofCmsghdr+alignedSizeofInet6Pktinfo) + cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) + *cmsghdr = Cmsghdr{ + Len: uintptr(sizeofCmsghdr + sizeofInet6Pktinfo), + Level: windows.IPPROTO_IPV6, + Type: windows.IPV6_PKTINFO, + } + pktinfo := Inet6Pktinfo{ + Addr: m.PktinfoAddr.As16(), + Ifindex: m.PktinfoIfindex, + } + _ = copy(msgBuf[sizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet6Pktinfo)) + } + + if m.SegmentSize > 0 { + var msgBuf []byte + b, msgBuf = slicehelper.Extend(b, sizeofCmsghdr+alignedSizeofCoalescedInfo) + cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) + *cmsghdr = Cmsghdr{ + Len: uintptr(sizeofCmsghdr + sizeofCoalescedInfo), + Level: windows.IPPROTO_UDP, + Type: windows.UDP_COALESCED_INFO, + } + _ = copy(msgBuf[sizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&m.SegmentSize)), sizeofCoalescedInfo)) + } + + return b +} diff --git a/conn/conn_darwinlinux.go b/conn/conn_darwinlinux.go deleted file mode 100644 index f9b011f..0000000 --- a/conn/conn_darwinlinux.go +++ /dev/null @@ -1,40 +0,0 @@ -//go:build darwin || linux - -package conn - -import ( - "fmt" - "net/netip" - "unsafe" - - "golang.org/x/sys/unix" -) - -// SocketControlMessageBufferSize specifies the buffer size for receiving socket control messages. -const SocketControlMessageBufferSize = unix.SizeofCmsghdr + (unix.SizeofInet6Pktinfo+unix.SizeofPtr-1) & ^(unix.SizeofPtr-1) - -// ParsePktinfoCmsg parses a single socket control message of type IP_PKTINFO or IPV6_PKTINFO, -// and returns the IP address and index of the network interface the packet was received from, -// or an error. -// -// This function is only implemented for Linux, macOS and Windows. On other platforms, this is a no-op. -func ParsePktinfoCmsg(cmsg []byte) (netip.Addr, uint32, error) { - if len(cmsg) < unix.SizeofCmsghdr { - return netip.Addr{}, 0, fmt.Errorf("control message length %d is shorter than cmsghdr length", len(cmsg)) - } - - cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(&cmsg[0])) - - switch { - case cmsghdr.Level == unix.IPPROTO_IP && cmsghdr.Type == unix.IP_PKTINFO && len(cmsg) >= unix.SizeofCmsghdr+unix.SizeofInet4Pktinfo: - pktinfo := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg[unix.SizeofCmsghdr])) - return netip.AddrFrom4(pktinfo.Spec_dst), uint32(pktinfo.Ifindex), nil - - case cmsghdr.Level == unix.IPPROTO_IPV6 && cmsghdr.Type == unix.IPV6_PKTINFO && len(cmsg) >= unix.SizeofCmsghdr+unix.SizeofInet6Pktinfo: - pktinfo := (*unix.Inet6Pktinfo)(unsafe.Pointer(&cmsg[unix.SizeofCmsghdr])) - return netip.AddrFrom16(pktinfo.Addr), pktinfo.Ifindex, nil - - default: - return netip.Addr{}, 0, fmt.Errorf("unknown control message level %d type %d", cmsghdr.Level, cmsghdr.Type) - } -} diff --git a/conn/conn_notdarwinlinuxwindows.go b/conn/conn_notdarwinlinuxwindows.go deleted file mode 100644 index bcfb9a8..0000000 --- a/conn/conn_notdarwinlinuxwindows.go +++ /dev/null @@ -1,17 +0,0 @@ -//go:build !darwin && !linux && !windows - -package conn - -import "net/netip" - -// SocketControlMessageBufferSize specifies the buffer size for receiving socket control messages. -const SocketControlMessageBufferSize = 0 - -// ParsePktinfoCmsg parses a single socket control message of type IP_PKTINFO or IPV6_PKTINFO, -// and returns the IP address and index of the network interface the packet was received from, -// or an error. -// -// This function is only implemented for Linux, macOS and Windows. On other platforms, this is a no-op. -func ParsePktinfoCmsg(cmsg []byte) (netip.Addr, uint32, error) { - return netip.Addr{}, 0, nil -} diff --git a/conn/conn_windows.go b/conn/conn_windows.go index a5d7f5b..664490d 100644 --- a/conn/conn_windows.go +++ b/conn/conn_windows.go @@ -2,8 +2,6 @@ package conn import ( "fmt" - "net/netip" - "unsafe" "golang.org/x/sys/windows" ) @@ -71,59 +69,3 @@ func (lso ListenerSocketOptions) buildSetFns() setFuncSlice { appendSetUDPGenericReceiveOffloadFunc(lso.UDPGenericReceiveOffload). appendSetRecvPktinfoFunc(lso.ReceivePacketInfo) } - -// Structure CMSGHDR from ws2def.h -type Cmsghdr struct { - Len uint - Level int32 - Type int32 -} - -// Structure IN_PKTINFO from ws2ipdef.h -type Inet4Pktinfo struct { - Addr [4]byte - Ifindex uint32 -} - -// Structure IN6_PKTINFO from ws2ipdef.h -type Inet6Pktinfo struct { - Addr [16]byte - Ifindex uint32 -} - -const ( - SizeofCmsghdr = unsafe.Sizeof(Cmsghdr{}) - SizeofInet4Pktinfo = unsafe.Sizeof(Inet4Pktinfo{}) - SizeofInet6Pktinfo = unsafe.Sizeof(Inet6Pktinfo{}) -) - -const SizeofPtr = unsafe.Sizeof(uintptr(0)) - -// SocketControlMessageBufferSize specifies the buffer size for receiving socket control messages. -const SocketControlMessageBufferSize = SizeofCmsghdr + (SizeofInet6Pktinfo+SizeofPtr-1) & ^(SizeofPtr-1) - -// ParsePktinfoCmsg parses a single socket control message of type IP_PKTINFO or IPV6_PKTINFO, -// and returns the IP address and index of the network interface the packet was received from, -// or an error. -// -// This function is only implemented for Linux and Windows. On other platforms, this is a no-op. -func ParsePktinfoCmsg(cmsg []byte) (netip.Addr, uint32, error) { - if len(cmsg) < int(SizeofCmsghdr) { - return netip.Addr{}, 0, fmt.Errorf("control message length %d is shorter than cmsghdr length", len(cmsg)) - } - - cmsghdr := (*Cmsghdr)(unsafe.Pointer(&cmsg[0])) - - switch { - case cmsghdr.Level == windows.IPPROTO_IP && cmsghdr.Type == windows.IP_PKTINFO && len(cmsg) >= int(SizeofCmsghdr+SizeofInet4Pktinfo): - pktinfo := (*Inet4Pktinfo)(unsafe.Pointer(&cmsg[SizeofCmsghdr])) - return netip.AddrFrom4(pktinfo.Addr), pktinfo.Ifindex, nil - - case cmsghdr.Level == windows.IPPROTO_IPV6 && cmsghdr.Type == windows.IPV6_PKTINFO && len(cmsg) >= int(SizeofCmsghdr+SizeofInet6Pktinfo): - pktinfo := (*Inet6Pktinfo)(unsafe.Pointer(&cmsg[SizeofCmsghdr])) - return netip.AddrFrom16(pktinfo.Addr), pktinfo.Ifindex, nil - - default: - return netip.Addr{}, 0, fmt.Errorf("unknown control message level %d type %d", cmsghdr.Level, cmsghdr.Type) - } -} diff --git a/service/client.go b/service/client.go index 7721a94..d611a1a 100644 --- a/service/client.go +++ b/service/client.go @@ -301,7 +301,7 @@ func (c *client) recvFromWgConnGeneric(ctx context.Context, wgConn *net.UDPConn) cmsg := cmsgBuf[:cmsgn] if !bytes.Equal(natEntry.clientPktinfoCache, cmsg) { - clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg) + m, err := conn.ParseSocketControlMessage(cmsg) if err != nil { c.logger.Warn("Failed to parse pktinfo control message from wgConn", zap.String("client", c.name), @@ -324,8 +324,8 @@ func (c *client) recvFromWgConnGeneric(ctx context.Context, wgConn *net.UDPConn) zap.String("client", c.name), zap.String("listenAddress", c.wgListenAddress), zap.Stringer("clientAddress", clientAddrPort), - zap.Stringer("clientPktinfoAddr", clientPktinfoAddr), - zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex), + zap.Stringer("clientPktinfoAddr", m.PktinfoAddr), + zap.Uint32("clientPktinfoIfindex", m.PktinfoIfindex), ) } } diff --git a/service/client_mmsg.go b/service/client_mmsg.go index f96eb28..7dfb2aa 100644 --- a/service/client_mmsg.go +++ b/service/client_mmsg.go @@ -189,7 +189,7 @@ func (c *client) recvFromWgConnRecvmmsg(ctx context.Context, wgConn *conn.MmsgRC cmsg := cmsgvec[i][:msg.Msghdr.Controllen] if !bytes.Equal(natEntry.clientPktinfoCache, cmsg) { - clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg) + m, err := conn.ParseSocketControlMessage(cmsg) if err != nil { c.logger.Warn("Failed to parse pktinfo control message from wgConn", zap.String("client", c.name), @@ -212,8 +212,8 @@ func (c *client) recvFromWgConnRecvmmsg(ctx context.Context, wgConn *conn.MmsgRC zap.String("client", c.name), zap.String("listenAddress", c.wgListenAddress), zap.Stringer("clientAddress", clientAddrPort), - zap.Stringer("clientPktinfoAddr", clientPktinfoAddr), - zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex), + zap.Stringer("clientPktinfoAddr", m.PktinfoAddr), + zap.Uint32("clientPktinfoIfindex", m.PktinfoIfindex), ) } } diff --git a/service/server.go b/service/server.go index 95abc2b..70eeb05 100644 --- a/service/server.go +++ b/service/server.go @@ -291,7 +291,7 @@ func (s *server) recvFromProxyConnGeneric(ctx context.Context, proxyConn *net.UD cmsg := cmsgBuf[:cmsgn] if !bytes.Equal(natEntry.clientPktinfoCache, cmsg) { - clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg) + m, err := conn.ParseSocketControlMessage(cmsg) if err != nil { s.logger.Warn("Failed to parse pktinfo control message from proxyConn", zap.String("server", s.name), @@ -314,8 +314,8 @@ func (s *server) recvFromProxyConnGeneric(ctx context.Context, proxyConn *net.UD zap.String("server", s.name), zap.String("listenAddress", s.proxyListenAddress), zap.Stringer("clientAddress", clientAddrPort), - zap.Stringer("clientPktinfoAddr", clientPktinfoAddr), - zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex), + zap.Stringer("clientPktinfoAddr", m.PktinfoAddr), + zap.Uint32("clientPktinfoIfindex", m.PktinfoIfindex), ) } } diff --git a/service/server_mmsg.go b/service/server_mmsg.go index 4e32ff2..fd1506c 100644 --- a/service/server_mmsg.go +++ b/service/server_mmsg.go @@ -187,7 +187,7 @@ func (s *server) recvFromProxyConnRecvmmsg(ctx context.Context, proxyConn *conn. cmsg := cmsgvec[i][:msg.Msghdr.Controllen] if !bytes.Equal(natEntry.clientPktinfoCache, cmsg) { - clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg) + m, err := conn.ParseSocketControlMessage(cmsg) if err != nil { s.logger.Warn("Failed to parse pktinfo control message from proxyConn", zap.String("server", s.name), @@ -210,8 +210,8 @@ func (s *server) recvFromProxyConnRecvmmsg(ctx context.Context, proxyConn *conn. zap.String("server", s.name), zap.String("listenAddress", s.proxyListenAddress), zap.Stringer("clientAddress", clientAddrPort), - zap.Stringer("clientPktinfoAddr", clientPktinfoAddr), - zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex), + zap.Stringer("clientPktinfoAddr", m.PktinfoAddr), + zap.Uint32("clientPktinfoIfindex", m.PktinfoIfindex), ) } } diff --git a/slicehelper/slicehelper.go b/slicehelper/slicehelper.go new file mode 100644 index 0000000..4feb642 --- /dev/null +++ b/slicehelper/slicehelper.go @@ -0,0 +1,15 @@ +package slicehelper + +// Extend extends the input slice by n elements. head is the full extended +// slice, while tail is the appended part. If the original slice has sufficient +// capacity no allocation is performed. +func Extend[S ~[]E, E any](in S, n int) (head, tail S) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make(S, total) + copy(head, in) + } + tail = head[len(in):] + return +}